tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tqc...@apache.org
Subject [incubator-tvm] branch master updated: [TFLite Runtime] Fix bug and re-enable RPC execution test (#5436)
Date Fri, 15 May 2020 03:17:06 GMT
This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new a400f82  [TFLite Runtime] Fix bug and re-enable RPC execution test (#5436)
a400f82 is described below

commit a400f825281f3c6f0688e8b16deea4ba12ee6bb5
Author: Michal Piszczek <imichaljp@gmail.com>
AuthorDate: Thu May 14 20:16:57 2020 -0700

    [TFLite Runtime] Fix bug and re-enable RPC execution test (#5436)
---
 src/runtime/contrib/tflite/tflite_runtime.cc |   8 +-
 src/runtime/contrib/tflite/tflite_runtime.h  |   3 +
 src/runtime/module.cc                        |   2 +
 tests/python/contrib/test_tflite_runtime.py  | 202 ++++++++++++++++-----------
 tests/scripts/task_config_build_cpu.sh       |   3 +
 5 files changed, 135 insertions(+), 83 deletions(-)

diff --git a/src/runtime/contrib/tflite/tflite_runtime.cc b/src/runtime/contrib/tflite/tflite_runtime.cc
index 53d7754..8b34e90 100644
--- a/src/runtime/contrib/tflite/tflite_runtime.cc
+++ b/src/runtime/contrib/tflite/tflite_runtime.cc
@@ -93,8 +93,12 @@ DataType TfLiteDType2TVMDType(TfLiteType dtype) {
 void TFLiteRuntime::Init(const std::string& tflite_model_bytes, TVMContext ctx) {
   const char* buffer = tflite_model_bytes.c_str();
   size_t buffer_size = tflite_model_bytes.size();
+  // The buffer used to construct the model must be kept alive for
+  // dependent interpreters to be used.
+  flatBuffersBuffer_ = std::unique_ptr<char[]>(new char[buffer_size]);
+  std::memcpy(flatBuffersBuffer_.get(), buffer, buffer_size);
   std::unique_ptr<tflite::FlatBufferModel> model =
-      tflite::FlatBufferModel::BuildFromBuffer(buffer, buffer_size);
+      tflite::FlatBufferModel::BuildFromBuffer(flatBuffersBuffer_.get(), buffer_size);
   tflite::ops::builtin::BuiltinOpResolver resolver;
   // Build interpreter
   TfLiteStatus status = tflite::InterpreterBuilder(*model, resolver)(&interpreter_);
@@ -173,5 +177,7 @@ Module TFLiteRuntimeCreate(const std::string& tflite_model_bytes, TVMContext ctx
 TVM_REGISTER_GLOBAL("tvm.tflite_runtime.create").set_body([](TVMArgs args, TVMRetValue* rv) {
   *rv = TFLiteRuntimeCreate(args[0], args[1]);
 });
+
+TVM_REGISTER_GLOBAL("target.runtime.tflite").set_body_typed(TFLiteRuntimeCreate);
 }  // namespace runtime
 }  // namespace tvm
diff --git a/src/runtime/contrib/tflite/tflite_runtime.h b/src/runtime/contrib/tflite/tflite_runtime.h
index f61f6ee..f3e3bd9 100644
--- a/src/runtime/contrib/tflite/tflite_runtime.h
+++ b/src/runtime/contrib/tflite/tflite_runtime.h
@@ -26,6 +26,7 @@
 #define TVM_RUNTIME_CONTRIB_TFLITE_TFLITE_RUNTIME_H_
 
 #include <dlpack/dlpack.h>
+#include <tensorflow/lite/interpreter.h>
 #include <tvm/runtime/ndarray.h>
 #include <tvm/runtime/packed_func.h>
 
@@ -93,6 +94,8 @@ class TFLiteRuntime : public ModuleNode {
    */
   NDArray GetOutput(int index) const;
 
+  // Buffer backing the interpreter's model
+  std::unique_ptr<char[]> flatBuffersBuffer_;
   // TFLite interpreter
   std::unique_ptr<tflite::Interpreter> interpreter_;
   // TVM context
diff --git a/src/runtime/module.cc b/src/runtime/module.cc
index be75ff2..46ef6fa 100644
--- a/src/runtime/module.cc
+++ b/src/runtime/module.cc
@@ -129,6 +129,8 @@ bool RuntimeEnabled(const std::string& target) {
     f_name = "device_api.opencl";
   } else if (target == "mtl" || target == "metal") {
     f_name = "device_api.metal";
+  } else if (target == "tflite") {
+    f_name = "target.runtime.tflite";
   } else if (target == "vulkan") {
     f_name = "device_api.vulkan";
   } else if (target == "stackvm") {
diff --git a/tests/python/contrib/test_tflite_runtime.py b/tests/python/contrib/test_tflite_runtime.py
index 8c883b0..1b911b7 100644
--- a/tests/python/contrib/test_tflite_runtime.py
+++ b/tests/python/contrib/test_tflite_runtime.py
@@ -14,92 +14,130 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import pytest
+
 import tvm
 from tvm import te
 import numpy as np
 from tvm import rpc
 from tvm.contrib import util, tflite_runtime
-# import tensorflow as tf
-# import tflite_runtime.interpreter as tflite
-
-
-def skipped_test_tflite_runtime():
-
-    def create_tflite_model():
-        root = tf.Module()
-        root.const = tf.constant([1., 2.], tf.float32)
-        root.f = tf.function(lambda x: root.const * x)
-
-        input_signature = tf.TensorSpec(shape=[2,  ], dtype=tf.float32)
-        concrete_func = root.f.get_concrete_function(input_signature)
-        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
-        tflite_model = converter.convert()
-        return tflite_model
-
-
-    def check_local():
-        tflite_fname = "model.tflite"
-        tflite_model = create_tflite_model()
-        temp = util.tempdir()
-        tflite_model_path = temp.relpath(tflite_fname)
-        open(tflite_model_path, 'wb').write(tflite_model)
-
-        # inference via tflite interpreter python apis
-        interpreter = tflite.Interpreter(model_path=tflite_model_path)
-        interpreter.allocate_tensors()
-        input_details = interpreter.get_input_details()
-        output_details = interpreter.get_output_details()
-
-        input_shape = input_details[0]['shape']
-        tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32)
-        interpreter.set_tensor(input_details[0]['index'], tflite_input)
-        interpreter.invoke()
-        tflite_output = interpreter.get_tensor(output_details[0]['index'])
-
-        # inference via tvm tflite runtime
-        with open(tflite_model_path, 'rb') as model_fin:
-            runtime = tflite_runtime.create(model_fin.read(), tvm.cpu(0))
-            runtime.set_input(0, tvm.nd.array(tflite_input))
-            runtime.invoke()
-            out = runtime.get_output(0)
-            np.testing.assert_equal(out.asnumpy(), tflite_output)
-
-
-    def check_remote():
-        tflite_fname = "model.tflite"
-        tflite_model = create_tflite_model()
-        temp = util.tempdir()
-        tflite_model_path = temp.relpath(tflite_fname)
-        open(tflite_model_path, 'wb').write(tflite_model)
-
-        # inference via tflite interpreter python apis
-        interpreter = tflite.Interpreter(model_path=tflite_model_path)
-        interpreter.allocate_tensors()
-        input_details = interpreter.get_input_details()
-        output_details = interpreter.get_output_details()
-
-        input_shape = input_details[0]['shape']
-        tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32)
-        interpreter.set_tensor(input_details[0]['index'], tflite_input)
-        interpreter.invoke()
-        tflite_output = interpreter.get_tensor(output_details[0]['index'])
-
-        # inference via remote tvm tflite runtime
-        server = rpc.Server("localhost")
-        remote = rpc.connect(server.host, server.port)
-        ctx = remote.cpu(0)
-        a = remote.upload(tflite_model_path)
-
-        with open(tflite_model_path, 'rb') as model_fin:
-            runtime = tflite_runtime.create(model_fin.read(), remote.cpu(0))
-            runtime.set_input(0, tvm.nd.array(tflite_input, remote.cpu(0)))
-            runtime.invoke()
-            out = runtime.get_output(0)
-            np.testing.assert_equal(out.asnumpy(), tflite_output)
-
-    check_local()
-    check_remote()
+
+
+def _create_tflite_model():
+    if not tvm.runtime.enabled("tflite"):
+        print("skip because tflite runtime is not enabled...")
+        return
+    if not tvm.get_global_func("tvm.tflite_runtime.create", True):
+        print("skip because tflite runtime is not enabled...")
+        return
+
+    try:
+        import tensorflow as tf
+    except ImportError:
+        print('skip because tensorflow not installed...')
+        return
+
+    root = tf.Module()
+    root.const = tf.constant([1., 2.], tf.float32)
+    root.f = tf.function(lambda x: root.const * x)
+
+    input_signature = tf.TensorSpec(shape=[2,  ], dtype=tf.float32)
+    concrete_func = root.f.get_concrete_function(input_signature)
+    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+    tflite_model = converter.convert()
+    return tflite_model
+
+
+@pytest.mark.skip('skip because accessing output tensor is flakey')
+def test_local():
+    if not tvm.runtime.enabled("tflite"):
+        print("skip because tflite runtime is not enabled...")
+        return
+    if not tvm.get_global_func("tvm.tflite_runtime.create", True):
+        print("skip because tflite runtime is not enabled...")
+        return
+
+    try:
+        import tensorflow as tf
+    except ImportError:
+        print('skip because tensorflow not installed...')
+        return
+
+    tflite_fname = "model.tflite"
+    tflite_model = _create_tflite_model()
+    temp = util.tempdir()
+    tflite_model_path = temp.relpath(tflite_fname)
+    open(tflite_model_path, 'wb').write(tflite_model)
+
+    # inference via tflite interpreter python apis
+    interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
+    interpreter.allocate_tensors()
+    input_details = interpreter.get_input_details()
+    output_details = interpreter.get_output_details()
+
+    input_shape = input_details[0]['shape']
+    tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32)
+    interpreter.set_tensor(input_details[0]['index'], tflite_input)
+    interpreter.invoke()
+    tflite_output = interpreter.get_tensor(output_details[0]['index'])
+
+    # inference via tvm tflite runtime
+    with open(tflite_model_path, 'rb') as model_fin:
+        runtime = tflite_runtime.create(model_fin.read(), tvm.cpu(0))
+        runtime.set_input(0, tvm.nd.array(tflite_input))
+        runtime.invoke()
+        out = runtime.get_output(0)
+        np.testing.assert_equal(out.asnumpy(), tflite_output)
+
+
+def test_remote():
+    if not tvm.runtime.enabled("tflite"):
+        print("skip because tflite runtime is not enabled...")
+        return
+    if not tvm.get_global_func("tvm.tflite_runtime.create", True):
+        print("skip because tflite runtime is not enabled...")
+        return
+
+    try:
+        import tensorflow as tf
+    except ImportError:
+        print('skip because tensorflow not installed...')
+        return
+
+    tflite_fname = "model.tflite"
+    tflite_model = _create_tflite_model()
+    temp = util.tempdir()
+    tflite_model_path = temp.relpath(tflite_fname)
+    open(tflite_model_path, 'wb').write(tflite_model)
+
+    # inference via tflite interpreter python apis
+    interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
+    interpreter.allocate_tensors()
+    input_details = interpreter.get_input_details()
+    output_details = interpreter.get_output_details()
+
+    input_shape = input_details[0]['shape']
+    tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32)
+    interpreter.set_tensor(input_details[0]['index'], tflite_input)
+    interpreter.invoke()
+    tflite_output = interpreter.get_tensor(output_details[0]['index'])
+
+    # inference via remote tvm tflite runtime
+    server = rpc.Server("localhost")
+    remote = rpc.connect(server.host, server.port)
+    ctx = remote.cpu(0)
+    a = remote.upload(tflite_model_path)
+
+    with open(tflite_model_path, 'rb') as model_fin:
+        runtime = tflite_runtime.create(model_fin.read(), remote.cpu(0))
+        runtime.set_input(0, tvm.nd.array(tflite_input, remote.cpu(0)))
+        runtime.invoke()
+        out = runtime.get_output(0)
+        np.testing.assert_equal(out.asnumpy(), tflite_output)
+
+    server.terminate()
+
 
 if __name__ == "__main__":
-    # skipped_test_tflite_runtime()
-    pass
+    test_local()
+    test_remote()
diff --git a/tests/scripts/task_config_build_cpu.sh b/tests/scripts/task_config_build_cpu.sh
index 9c1cf28..ce545bd 100755
--- a/tests/scripts/task_config_build_cpu.sh
+++ b/tests/scripts/task_config_build_cpu.sh
@@ -38,3 +38,6 @@ echo set\(CMAKE_CXX_FLAGS -Werror\) >> config.cmake
 echo set\(HIDE_PRIVATE_SYMBOLS ON\) >> config.cmake
 echo set\(USE_VTA_TSIM ON\) >> config.cmake
 echo set\(USE_VTA_FSIM ON\) >> config.cmake
+echo set\(USE_TFLITE ON\) >> config.cmake
+echo set\(USE_TENSORFLOW_PATH \"/tensorflow\"\) >> config.cmake
+echo set\(USE_FLATBUFFERS_PATH \"/flatbuffers\"\) >> config.cmake


Mime
View raw message