tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From haic...@apache.org
Subject [incubator-tvm] branch master updated: [Relay][VM] Fix compilation of If-Elses (#5040)
Date Wed, 11 Mar 2020 19:26:36 GMT
This is an automated email from the ASF dual-hosted git repository.

haichen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 2586b4d  [Relay][VM] Fix compilation of If-Elses (#5040)
2586b4d is described below

commit 2586b4d32501bec8ff9b7cea145376615ecb00c9
Author: Wei Chen <ipondering.weic@gmail.com>
AuthorDate: Thu Mar 12 03:26:28 2020 +0800

    [Relay][VM] Fix compilation of If-Elses (#5040)
---
 src/relay/backend/vm/compiler.cc |  8 +++++---
 tests/python/relay/test_vm.py    | 19 +++++++++++++++++++
 2 files changed, 24 insertions(+), 3 deletions(-)

diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index fc52a8e..e3c8d12 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -366,7 +366,9 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
     this->Emit(Instruction::If(test_register, target_register, 0, 0));
     this->VisitExpr(if_node->true_branch);
 
-    size_t true_register = last_register_;
+    // It saves the result of If-Else expression.
+    auto merge_register = NewRegister();
+    Emit(Instruction::Move(last_register_, merge_register));
     Emit(Instruction::Goto(0));
 
     // Finally store how many instructions there are in the
@@ -378,7 +380,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
     size_t false_register = last_register_;
 
     // In else-branch, override the then-branch register
-    Emit(Instruction::Move(false_register, true_register));
+    Emit(Instruction::Move(false_register, merge_register));
     // Compute the total number of instructions
     // after generating false.
     auto after_false = this->instructions_.size();
@@ -397,7 +399,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
     // Patch the Goto.
     this->instructions_[after_true - 1].pc_offset = (after_false - after_true) + 1;
 
-    this->last_register_ = true_register;
+    this->last_register_ = merge_register;
   }
 
   void EmitShapeFunc(Function func, Array<Expr> inputs, Array<Expr> outputs) {
diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py
index 02f1e5b..a8ac27a 100644
--- a/tests/python/relay/test_vm.py
+++ b/tests/python/relay/test_vm.py
@@ -142,6 +142,25 @@ def test_simple_if():
     # diff
     check_result([x_data, y_data], y_data, mod=mod)
 
+def test_multiple_ifs():
+    mod = tvm.IRModule({})
+    b = relay.var('b')
+    v0 = relay.var('v0')
+    v1 = relay.var('v1')
+    v2 = relay.var('v2')
+    v3 = relay.var('v3')
+    out = relay.Tuple([v2, v3])
+    out = relay.Let(v3, relay.If(b, v1, v0), out)
+    out = relay.Let(v2, relay.If(b, v0, v1), out)
+    out = relay.Let(v1, relay.Tuple([relay.const(1)]), out)
+    out = relay.Let(v0, relay.Tuple([relay.const(0)]), out)
+    fn = relay.Function([b], out)
+    mod['main'] = fn
+    ctx = tvm.runtime.ndarray.context('llvm', 0)
+    vm = relay.create_executor(ctx=ctx, mod=mod, kind='vm')
+    res = vmobj_to_list(vm.evaluate()(False))
+    assert(res == [1, 0])
+
 def test_simple_call():
     mod = tvm.IRModule({})
     sum_up = relay.GlobalVar('sum_up')


Mime
View raw message