tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tqc...@apache.org
Subject [incubator-tvm] branch master updated: [RELAY][VM] Add shape_of instruction (#5855)
Date Sun, 28 Jun 2020 17:06:02 GMT
This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 7312934  [RELAY][VM] Add shape_of instruction (#5855)
7312934 is described below

commit 731293462cd54b83566dfabf35daa36a75d56839
Author: Zhi <5145158+zhiics@users.noreply.github.com>
AuthorDate: Sun Jun 28 10:05:50 2020 -0700

    [RELAY][VM] Add shape_of instruction (#5855)
---
 include/tvm/runtime/vm.h                    |  12 ++++
 python/tvm/relay/op/__init__.py             |   1 +
 python/tvm/relay/op/vm/__init__.py          |  20 ++++++
 python/tvm/relay/op/vm/_ffi_api.py          |  20 ++++++
 python/tvm/relay/op/vm/vm.py                |  35 +++++++++
 python/tvm/relay/transform/memory_alloc.py  |   4 +-
 src/relay/backend/vm/compiler.cc            |  13 ++++
 src/relay/op/tensor/unary.cc                |  14 ----
 src/relay/op/type_relations.cc              |  15 ++++
 src/relay/op/type_relations.h               |  12 ++++
 src/relay/op/vm/vm.cc                       |  58 +++++++++++++++
 src/relay/transforms/fold_constant.cc       |   4 +-
 src/runtime/vm/executable.cc                |  10 +++
 src/runtime/vm/vm.cc                        |  31 ++++++++
 tests/python/relay/test_vm_serialization.py | 107 ++++++++++------------------
 15 files changed, 268 insertions(+), 88 deletions(-)

diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h
index b9ccbf9..0cce533 100644
--- a/include/tvm/runtime/vm.h
+++ b/include/tvm/runtime/vm.h
@@ -114,6 +114,7 @@ enum class Opcode {
   LoadConsti = 14U,
   Fatal = 15U,
   AllocStorage = 16U,
+  ShapeOf = 17U,
 };
 
 /*! \brief A single virtual machine instruction.
@@ -245,6 +246,9 @@ struct Instruction {
       /*! \brief The hint of the dtype. */
       DLDataType dtype_hint;
     } alloc_storage;
+    struct /* ShapeOf Operands */ {
+      RegName tensor;
+    } shape_of;
   };
 
   /*!
@@ -389,6 +393,14 @@ struct Instruction {
   static Instruction AllocStorage(RegName size, Index alignment, DLDataType dtype_hint,
                                   RegName dst);
 
+  /*!
+   * \brief Get the shape of an input tensor.
+   * \param tensor The register holding the input tensor.
+   * \param dst The destination register that receives the shape of the given tensor.
+   * \return The ShapeOf instruction.
+   */
+  static Instruction ShapeOf(RegName tensor, RegName dst);
+
   Instruction();
   Instruction(const Instruction& instr);
   Instruction& operator=(const Instruction& instr);
diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py
index ce0df95..a45d466 100644
--- a/python/tvm/relay/op/__init__.py
+++ b/python/tvm/relay/op/__init__.py
@@ -27,6 +27,7 @@ from .reduce import *
 from .tensor import *
 from .transform import *
 from .algorithm import *
+from .vm import *
 from . import nn
 from . import annotation
 from . import memory
diff --git a/python/tvm/relay/op/vm/__init__.py b/python/tvm/relay/op/vm/__init__.py
new file mode 100644
index 0000000..2ac1e57
--- /dev/null
+++ b/python/tvm/relay/op/vm/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=wildcard-import
+"""Dialect operators for Relay VM."""
+from __future__ import absolute_import as _abs
+from . import vm
diff --git a/python/tvm/relay/op/vm/_ffi_api.py b/python/tvm/relay/op/vm/_ffi_api.py
new file mode 100644
index 0000000..3eeeeb8
--- /dev/null
+++ b/python/tvm/relay/op/vm/_ffi_api.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""FFI APIs for relay.op.vm"""
+import tvm._ffi
+
+tvm._ffi._init_api("relay.op.vm", __name__)
diff --git a/python/tvm/relay/op/vm/vm.py b/python/tvm/relay/op/vm/vm.py
new file mode 100644
index 0000000..680729d
--- /dev/null
+++ b/python/tvm/relay/op/vm/vm.py
@@ -0,0 +1,35 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=no-else-return,invalid-name,len-as-condition,too-many-nested-blocks
+"""Dialect operators for Relay VM."""
+from . import _ffi_api
+
+
+def shape_of(expr):
+    """Build a call to the ``vm.shape_of`` operator on a tensor expression.
+
+    Parameters
+    ----------
+    expr : tvm.relay.Expr
+        The tensor expression whose shape is queried.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        A call that evaluates to the shape of ``expr`` as an int64 tensor.
+    """
+    return _ffi_api.shape_of(expr)
diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py
index 6c081cb..a7ba2a8 100644
--- a/python/tvm/relay/transform/memory_alloc.py
+++ b/python/tvm/relay/transform/memory_alloc.py
@@ -44,6 +44,7 @@ class ManifestAllocPass(ExprMutator):
     def __init__(self, target_host):
         self.invoke_tvm = op.memory.invoke_tvm_op
         self.shape_func = op.memory.shape_func
+        self.shape_of = op.vm.shape_of
         self.scopes = [ScopeBuilder()]
         self.target_host = target_host
         self.default_context = cpu(0)
@@ -53,9 +54,6 @@ class ManifestAllocPass(ExprMutator):
     def current_scope(self):
         return self.scopes[-1]
 
-    def shape_of(self, e):
-        return op.shape_of(e, self.compute_dtype)
-
     def visit_tuple(self, tup):
         scope = self.current_scope()
         new_fields = []
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index 0b839a2..2151acf 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -283,6 +283,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
       case Opcode::Invoke:
       case Opcode::AllocClosure:
       case Opcode::AllocStorage:
+      case Opcode::ShapeOf:
       case Opcode::Move:
       case Opcode::InvokeClosure:
         last_register_ = instr.dst;
@@ -588,6 +589,18 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
                    auto outputs = Downcast<Tuple>(args[2]);
                    EmitShapeFunc(shape_func, inputs->fields, outputs->fields);
                  })
+          .Match("vm.shape_of",
+                 [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
+                   CHECK_EQ(args.size(), 1U);
+                   // Validate the attributes: the runtime always emits an int64 shape tensor.
+                   const auto* shape_of_attrs = attrs.as<ShapeOfAttrs>();
+                   CHECK(shape_of_attrs) << "Must be the shape_of attrs";
+                   CHECK(shape_of_attrs->dtype.is_int() && shape_of_attrs->dtype.bits() == 64)
+                       << "The dtype of shape_of must be int64, but got "
+                       << DLDataType2String(shape_of_attrs->dtype);
+                   this->VisitExpr(args[0]);  // expected to leave the operand's register in last_register_
+                   Emit(Instruction::ShapeOf(last_register_, NewRegister()));
+                 })
           .Match("memory.kill",
                  [](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
                    LOG(FATAL) << "memory.kill is not yet supported";
diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc
index 6b72670..99e6c02 100644
--- a/src/relay/op/tensor/unary.cc
+++ b/src/relay/op/tensor/unary.cc
@@ -396,20 +396,6 @@ RELAY_REGISTER_UNARY_OP("bitwise_not")
 // shape_of
 TVM_REGISTER_NODE_TYPE(ShapeOfAttrs);
 
-bool ShapeOfRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
-                const TypeReporter& reporter) {
-  CHECK_EQ(num_inputs, 1);
-  auto tt = types[0].as<TensorTypeNode>();
-  if (tt == nullptr) {
-    return false;
-  }
-  const auto* param = attrs.as<ShapeOfAttrs>();
-  CHECK(param != nullptr);
-  auto rank_shape = RankShape(tt->shape);
-  reporter->Assign(types[1], TensorType(rank_shape, param->dtype));
-  return true;
-}
-
 Array<te::Tensor> ShapeOfCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                  const Type& out_type) {
   CHECK_EQ(inputs.size(), 1);
diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc
index 46143d1..0647ec9 100644
--- a/src/relay/op/type_relations.cc
+++ b/src/relay/op/type_relations.cc
@@ -25,6 +25,7 @@
 #include "./type_relations.h"
 
 #include <tvm/arith/analyzer.h>
+#include <tvm/relay/attrs/transform.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/op.h>
 #include <tvm/tir/op.h>
@@ -146,5 +147,19 @@ Array<IndexExpr> RankShape(const Array<IndexExpr>& shape) {
   }
 }
 
+bool ShapeOfRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                const TypeReporter& reporter) {
+  CHECK_EQ(num_inputs, 1);  // types[0] is the input tensor type, types[1] the output type
+  auto tt = types[0].as<TensorTypeNode>();
+  if (tt == nullptr) {
+    return false;  // input type not inferred yet; report the relation as unresolved
+  }
+  const auto* param = attrs.as<ShapeOfAttrs>();
+  CHECK(param != nullptr);
+  auto rank_shape = RankShape(tt->shape);  // presumably the 1-D shape [ndim] -- confirm in RankShape
+  reporter->Assign(types[1], TensorType(rank_shape, param->dtype));  // output dtype comes from attrs
+  return true;
+}
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/op/type_relations.h b/src/relay/op/type_relations.h
index acd4b2d..5ab8b12 100644
--- a/src/relay/op/type_relations.h
+++ b/src/relay/op/type_relations.h
@@ -79,6 +79,18 @@ bool IdentityCompRel(const Array<Type>& types, int num_inputs, const Attrs& attr
 
 Array<IndexExpr> RankShape(const Array<IndexExpr>& shape);
 
+/*!
+ * \brief The type relation for shape_of-style operators (e.g. vm.shape_of).
+ *
+ * \param types The input and output types to the relation.
+ * \param num_inputs The number of input arguments.
+ * \param attrs The attributes (expected to be ShapeOfAttrs).
+ * \param reporter The reporter.
+ * \return true if the relation has been resolved.
+ */
+bool ShapeOfRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                const TypeReporter& reporter);
+
 }  // namespace relay
 }  // namespace tvm
 
diff --git a/src/relay/op/vm/vm.cc b/src/relay/op/vm/vm.cc
new file mode 100644
index 0000000..af33100
--- /dev/null
+++ b/src/relay/op/vm/vm.cc
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/op/vm/vm.cc
+ * \brief Dialect operators for Relay VM.
+ */
+
+#include <topi/elemwise.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/data_type.h>
+
+#include "../../transforms/infer_layout_util.h"
+#include "../op_common.h"
+#include "../type_relations.h"
+
+namespace tvm {
+namespace relay {
+
+RELAY_REGISTER_OP("vm.shape_of")
+    .describe(R"code(Get the shape of an input tensor.
+)code" TVM_ADD_FILELINE)
+    .set_num_inputs(1)
+    .add_argument("tensor", "Tensor", "The input tensor")
+    .add_type_rel("ShapeOf", ShapeOfRel)
+    .set_support_level(10)
+    .set_attr<TOpPattern>("TOpPattern", kOpaque)
+    .set_attr<TOpIsStateful>("TOpIsStateful", false)  // not stateful, so FoldConstant may evaluate it
+    .set_attr<TNonComputational>("TNonComputational", true)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
+
+TVM_REGISTER_GLOBAL("relay.op.vm.shape_of").set_body_typed([](Expr expr) {  // used by op/vm/_ffi_api.py
+  auto attrs = make_object<ShapeOfAttrs>();
+  attrs->dtype = DataType::Int(64);  // the VM compiler only accepts a 64-bit shape dtype
+  static const Op& op = Op::Get("vm.shape_of");
+  return Call(op, {expr}, Attrs(attrs), {});
+});
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc
index b2eab8f..50de871 100644
--- a/src/relay/transforms/fold_constant.cc
+++ b/src/relay/transforms/fold_constant.cc
@@ -81,6 +81,7 @@ class ConstantFolder : public ExprMutator {
       : executor_(executor),
         module_(module),
         shape_of_op_(Op::Get("shape_of")),
+        vm_shape_of_op_(Op::Get("vm.shape_of")),
         invoke_tvm_op_(Op::Get("memory.invoke_tvm_op")),
         shape_func_op_(Op::Get("memory.shape_func")),
         alloc_tensor_op_(Op::Get("memory.alloc_tensor")),
@@ -123,7 +124,7 @@ class ConstantFolder : public ExprMutator {
     // skip stateful ops.
     if (op_stateful.get(GetRef<Op>(op), false)) return res;
     // Try to evaluate shape_of op
-    if (call->op == shape_of_op_) {
+    if (call->op == shape_of_op_ || call->op == vm_shape_of_op_) {
       return EvaluateShapeOf(res, origin_args, call->attrs);
     }
 
@@ -166,6 +167,7 @@ class ConstantFolder : public ExprMutator {
 
   // Cache the following ops for equivalence checking in this pass.
   const Op& shape_of_op_;
+  const Op& vm_shape_of_op_;
   const Op& invoke_tvm_op_;
   const Op& shape_func_op_;
   const Op& alloc_tensor_op_;
diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc
index 65b1a2f..f520404 100644
--- a/src/runtime/vm/executable.cc
+++ b/src/runtime/vm/executable.cc
@@ -417,6 +417,11 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) {
       fields.push_back(instr.pc_offset);
       break;
     }
+    case Opcode::ShapeOf: {
+      // Number of fields = 2
+      fields.assign({instr.shape_of.tensor, instr.dst});
+      break;
+    }
     default:
       LOG(FATAL) << "Invalid opcode" << static_cast<int>(instr.op);
       break;
@@ -683,6 +688,11 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) {
       DCHECK_EQ(instr.fields.size(), 1U);
       return Instruction::Goto(instr.fields[0]);
     }
+    case Opcode::ShapeOf: {
+      // Number of fields = 2
+      DCHECK_EQ(instr.fields.size(), 2U);
+      return Instruction::ShapeOf(instr.fields[0], instr.fields[1]);
+    }
     default:
       LOG(FATAL) << "Invalid opcode" << instr.opcode;
       return Instruction();
diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc
index 0c0ca35..6b10a89 100644
--- a/src/runtime/vm/vm.cc
+++ b/src/runtime/vm/vm.cc
@@ -145,6 +145,9 @@ Instruction::Instruction(const Instruction& instr) {
     case Opcode::AllocStorage:
       this->alloc_storage = instr.alloc_storage;
       return;
+    case Opcode::ShapeOf:
+      this->shape_of.tensor = instr.shape_of.tensor;
+      return;
     default:
       std::ostringstream out;
       out << "Invalid instruction " << static_cast<int>(instr.op);
@@ -239,6 +242,9 @@ Instruction& Instruction::operator=(const Instruction& instr) {
     case Opcode::AllocStorage:
       this->alloc_storage = instr.alloc_storage;
       return *this;
+    case Opcode::ShapeOf:
+      this->shape_of.tensor = instr.shape_of.tensor;
+      return *this;
     default:
       std::ostringstream out;
       out << "Invalid instruction " << static_cast<int>(instr.op);
@@ -258,6 +264,7 @@ Instruction::~Instruction() {
     case Opcode::Goto:
     case Opcode::LoadConsti:
     case Opcode::AllocStorage:
+    case Opcode::ShapeOf:
     case Opcode::Fatal:
       return;
     case Opcode::AllocTensor:
@@ -351,6 +358,14 @@ Instruction Instruction::AllocStorage(RegName size, Index alignment, DLDataType
   return instr;
 }
 
+Instruction Instruction::ShapeOf(RegName tensor, RegName dst) {  // matches the vm.h declaration
+  Instruction instr;
+  instr.op = Opcode::ShapeOf;
+  instr.dst = dst;  // register that receives the int64 shape tensor at run time
+  instr.shape_of.tensor = tensor;  // register holding the input tensor
+  return instr;
+}
+
 Instruction Instruction::AllocADT(Index tag, Index num_fields,
                                   const std::vector<RegName>& datatype_fields, Index dst) {
   Instruction instr;
@@ -585,6 +600,10 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
          << DLDataType2String(instr.alloc_storage.dtype_hint);
       break;
     }
+    case Opcode::ShapeOf: {
+      os << "shape_of $" << instr.dst << " $" << instr.shape_of.tensor;
+      break;
+    }
     default:
       LOG(FATAL) << "should never hit this case" << static_cast<int>(instr.op);
       break;
@@ -1057,6 +1076,18 @@ void VirtualMachine::RunLoop() {
         pc_++;
         goto main_loop;
       }
+      case Opcode::ShapeOf: {
+        auto input = ReadRegister(instr.shape_of.tensor);
+        NDArray input_array = Downcast<NDArray>(input);
+        int ndim = input_array->ndim;
+        auto out_tensor = NDArray::Empty({ndim}, {kDLInt, 64, 1}, {kDLCPU, 0});
+        for (int i = 0; i < ndim; ++i) {
+          reinterpret_cast<int64_t*>(out_tensor->data)[i] = input_array->shape[i];
+        }
+        WriteRegister(instr.dst, out_tensor);
+        pc_++;
+        goto main_loop;
+      }
       case Opcode::Ret: {
         // If we have hit the point from which we started
         // running, we should return to the caller breaking
diff --git a/tests/python/relay/test_vm_serialization.py b/tests/python/relay/test_vm_serialization.py
index 5d20651..95e6c6f 100644
--- a/tests/python/relay/test_vm_serialization.py
+++ b/tests/python/relay/test_vm_serialization.py
@@ -19,7 +19,6 @@
 import numpy as np
 
 import tvm
-from tvm import te
 from tvm.runtime import vm as _vm
 from tvm.relay import vm as rly_vm
 from tvm import relay
@@ -41,11 +40,15 @@ def create_exec(f, target="llvm", params=None):
         return executable
 
 
-def veval(vm, *args, ctx=tvm.cpu()):
-    assert isinstance(vm, _vm.VirtualMachine), "expected VirtualMachine"
-    ret = vm.run(*args)
-    return ret
-
+def get_serialized_output(mod, *data, params=None, target="llvm",
+                          ctx=tvm.cpu()):  # NOTE: default ctx is evaluated once, at definition time
+    exe = create_exec(mod, target, params=params)
+    code, lib = exe.save()  # save the executable, then reload it to exercise (de)serialization
+    des_exec = _vm.Executable.load_exec(code, lib)
+    des_vm = _vm.VirtualMachine(des_exec)
+    des_vm.init(ctx)
+    result = des_vm.run(*data)  # run the deserialized VM on the given inputs
+    return result
 
 def run_network(mod,
                 params,
@@ -56,24 +59,16 @@ def run_network(mod,
         result = ex.evaluate()(data, **params)
         return result.asnumpy().astype(dtype)
 
-    def get_serialized_output(mod, data, params, target, ctx, dtype='float32'):
-        exe = create_exec(mod, target, params=params)
-        code, lib = exe.save()
-        des_exec = _vm.Executable.load_exec(code, lib)
-        des_vm = _vm.VirtualMachine(des_exec)
-        des_vm.init(ctx)
-        result = des_vm.run(data)
-        return result.asnumpy().astype(dtype)
-
     data = np.random.uniform(size=data_shape).astype(dtype)
     target = "llvm"
     ctx = tvm.cpu(0)
 
     tvm_out = get_vm_output(mod, tvm.nd.array(data.astype(dtype)), params,
                             target, ctx, dtype)
-    vm_out = get_serialized_output(mod, tvm.nd.array(data.astype(dtype)), params,
-                                   target, ctx, dtype)
-    tvm.testing.assert_allclose(vm_out, tvm_out, rtol=1e-5, atol=1e-5)
+    vm_out = get_serialized_output(mod, tvm.nd.array(data.astype(dtype)),
+                                   params=params, target=target, ctx=ctx)
+    tvm.testing.assert_allclose(vm_out.asnumpy().astype(dtype), tvm_out,
+                                rtol=1e-5, atol=1e-5)
 
 
 def test_serializer():
@@ -143,7 +138,7 @@ def test_save_load():
     des_vm = _vm.VirtualMachine(des_exec)
     des_vm.init(tvm.cpu())
 
-    res = veval(des_vm, x_data)
+    res = des_vm.run(x_data)
     tvm.testing.assert_allclose(res.asnumpy(), x_data + x_data)
 
 
@@ -151,14 +146,8 @@ def test_const():
     c = relay.const(1.0, "float32")
     x = relay.var('x', shape=(10, 10), dtype='float32')
     f = relay.Function([x], x + c)
-    exe = create_exec(f)
-    code, lib = exe.save()
-    assert isinstance(code, bytearray)
-    des_exec = _vm.Executable.load_exec(code, lib)
-    des_vm = _vm.VirtualMachine(des_exec)
-    des_vm.init(tvm.cpu())
     x_data = np.random.rand(10, 10).astype('float32')
-    res = veval(des_vm, x_data)
+    res = get_serialized_output(f, x_data)
     tvm.testing.assert_allclose(res.asnumpy(), x_data + 1)
 
 
@@ -172,18 +161,12 @@ def test_if():
     x_data = np.random.rand(10, 10).astype('float32')
     y_data = np.random.rand(10, 10).astype('float32')
 
-    exe = create_exec(f)
-    code, lib = exe.save()
-    des_exec = _vm.Executable.load_exec(code, lib)
-    des_vm = _vm.VirtualMachine(des_exec)
-    des_vm.init(tvm.cpu())
-
     # same
-    res = veval(des_vm, x_data, x_data)
+    res = get_serialized_output(f, x_data, x_data)
     tvm.testing.assert_allclose(res.asnumpy(), x_data)
 
     # diff
-    res = veval(des_vm, x_data, y_data)
+    res = get_serialized_output(f, x_data, y_data)
     tvm.testing.assert_allclose(res.asnumpy(), y_data)
 
 
@@ -208,13 +191,7 @@ def test_loop():
     aarg = relay.var('accum', shape=[], dtype='int32')
     mod["main"] = relay.Function([iarg, aarg], sum_up(iarg, aarg))
 
-    exe = create_exec(mod)
-    code, lib = exe.save()
-    des_exec = _vm.Executable.load_exec(code, lib)
-    des_vm = _vm.VirtualMachine(des_exec)
-    des_vm.init(tvm.cpu())
-
-    result = veval(des_vm, i_data, accum_data)
+    result = get_serialized_output(mod, i_data, accum_data)
     tvm.testing.assert_allclose(result.asnumpy(), sum(range(1, loop_bound + 1)))
 
 
@@ -225,13 +202,7 @@ def test_tuple():
     i_data = np.random.rand(41).astype('float32')
     j_data = np.random.rand(10).astype('float32')
 
-    exe = create_exec(f)
-    code, lib = exe.save()
-    des_exec = _vm.Executable.load_exec(code, lib)
-    des_vm = _vm.VirtualMachine(des_exec)
-    des_vm.init(tvm.cpu())
-
-    result = veval(des_vm, (i_data, j_data))
+    result = get_serialized_output(f, (i_data, j_data))
     tvm.testing.assert_allclose(result.asnumpy(), j_data)
 
 
@@ -246,13 +217,7 @@ def test_adt_list():
     f = relay.Function([], l321)
     mod["main"] = f
 
-    exe = create_exec(mod)
-    code, lib = exe.save()
-    des_exec = _vm.Executable.load_exec(code, lib)
-    des_vm = _vm.VirtualMachine(des_exec)
-    des_vm.init(tvm.cpu())
-
-    result = veval(des_vm)
+    result = get_serialized_output(mod)
     assert len(result) == 2
     assert len(result[1]) == 2
     assert len(result[1][1]) == 2
@@ -292,15 +257,8 @@ def test_adt_compose():
     f = relay.Function([y], add_two_body)
     mod["main"] = f
 
-    exe = create_exec(mod)
-    code, lib = exe.save()
-    des_exec = _vm.Executable.load_exec(code, lib)
-    des_vm = _vm.VirtualMachine(des_exec)
-    des_vm.init(tvm.cpu())
-
     x_data = np.array(np.random.rand()).astype('float32')
-    result = veval(des_vm, x_data)
-
+    result = get_serialized_output(mod, x_data)
     tvm.testing.assert_allclose(result.asnumpy(), x_data + 2.0)
 
 
@@ -312,13 +270,7 @@ def test_closure():
     clo = ff(relay.const(1.0))
     main = clo(relay.const(2.0))
 
-    exe = create_exec(main)
-    code, lib = exe.save()
-    des_exec = _vm.Executable.load_exec(code, lib)
-    des_vm = _vm.VirtualMachine(des_exec)
-    des_vm.init(tvm.cpu())
-
-    res = veval(des_vm)
+    res = get_serialized_output(main)
     tvm.testing.assert_allclose(res.asnumpy(), 3.0)
 
 
@@ -332,6 +284,20 @@ def test_mobilenet():
     run_network(mod, params)
 
 
+def test_vm_shape_of():
+    x = relay.var('x', shape=(relay.Any(), relay.Any(), relay.Any()), dtype="float32")  # all dims dynamic
+    relu_x = relay.nn.relu(x)
+    data = np.random.uniform(size=(2, 3, 4)).astype('float32')  # values in [0, 1), so relu is identity
+    args = [data]
+
+    newshape_var = relay.var('newshape', shape=(2,), dtype='int64')
+    args.append(np.array((1, -1), dtype='int64'))  # reshape (2, 3, 4) -> (1, 24)
+    main = relay.reshape(relu_x, newshape=newshape_var)
+
+    res = get_serialized_output(main, *args).asnumpy()
+    tvm.testing.assert_allclose(res.flatten(), data.flatten())
+
+
 if __name__ == "__main__":
     test_serializer()
     test_save_load()
@@ -344,3 +310,4 @@ if __name__ == "__main__":
     test_closure()
     test_resnet()
     test_mobilenet()
+    test_vm_shape_of()


Mime
View raw message