tvm-commits mailing list archives

From GitBox <...@apache.org>
Subject [GitHub] [incubator-tvm] areusch commented on a change in pull request #5484: [REFACTOR][RPC][PROCOTOL-CHANGE] Modularize the RPC infra
Date Mon, 04 May 2020 17:34:03 GMT

areusch commented on a change in pull request #5484:
URL: https://github.com/apache/incubator-tvm/pull/5484#discussion_r419606101



##########
File path: jvm/core/src/main/java/org/apache/tvm/contrib/GraphRuntime.java
##########
@@ -38,53 +38,14 @@
    * @return Runtime graph module that can be used to execute the graph.
    */
   public static GraphModule create(String graphJson, Module libmod, TVMContext ctx) {
-    Module graphModule = null;
-    if (ctx.deviceType >= RPC.RPC_SESS_MASK) {
-      if (!(ctx instanceof  TVMRemoteContext)) {
-        throw new IllegalArgumentException(
-            "Looks like you are using remote context with no RPCSession bind."
-            + "Use session.context instead.");
-      }
-      RPCSession rpcSession = ((TVMRemoteContext) ctx).rpcSession;
-      // check arguments
-      if (!"rpc".equals(libmod.typeKey())) {
-        throw new IllegalArgumentException("libmod.typeKey != rpc");
-      }
-      final int sessIndex = (int) ((Function) reflectionStaticCall(
-          RPC.class, "getApi", "_SessTableIndex"))
-          .pushArg(libmod).invoke().asLong();
-      if (sessIndex != (Integer) reflectionGetField(rpcSession, "tblIndex")) {
-        throw new IllegalArgumentException(String.format(
-            "libmod SessTableIndex=%d mismatch rpcSession.tblIndex=%d",
-            sessIndex, reflectionGetField(rpcSession, "tblIndex")));
-      }
-
-      Function rpcModuleHandle = (Function) reflectionStaticCall(
-          RPC.class, "getApi","_ModuleHandle");
-      if (rpcModuleHandle == null) {
-        throw new RuntimeException("Cannot find global function tvm.rpc._ModuleHandle."
-            + "Did you compile tvm_runtime with the correct version?");
-      }
-
-      Function fcreate = Function.getFunction("tvm.graph_runtime.remote_create");
-      if (fcreate == null) {
-        throw new RuntimeException("Cannot find global function tvm.graph_runtime.remote_create."
-            + "Did you compile tvm_runtime with correct version?");
-      }
-
-      TVMValue hmod = rpcModuleHandle.pushArg(libmod).invoke();
-      graphModule = fcreate.call(graphJson, hmod,
-          ctx.deviceType % RPC.RPC_SESS_MASK, ctx.deviceId).asModule();
-    } else {
-      Function fcreate = Function.getFunction("tvm.graph_runtime.create");
-      if (fcreate == null) {
-        throw new RuntimeException("Cannot find global function tvm.graph_runtime.create."
-            + "Did you compile tvm_runtime with correct version?");
-      }
-      graphModule = fcreate.pushArg(graphJson)
-          .pushArg(libmod).pushArg(ctx.deviceType).pushArg(ctx.deviceId)
-          .invoke().asModule();
+    Function fcreate = Function.getFunction("tvm.graph_runtime.create");
+    if (fcreate == null) {

Review comment:
      It would be better to move this lookup and null check into a new helper function, `Function.getFunctionOrThrow()`.
That doesn't need to happen in this PR, though.
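
The removed code repeats the same pattern three times: look up a global function, check for `null`, and throw a "Did you compile tvm_runtime with the correct version?" error. Below is a minimal sketch of the suggested helper. It rests on several assumptions. `getFunctionOrThrow` is only the name proposed in this comment and is not an existing TVM API. The `Function` class here is a hypothetical stand-in, backed by a plain `HashMap` registry so that it runs without the native runtime. Only the "`getFunction` returns `null` when the name is unknown" behavior is taken from the diff.

```java
import java.util.HashMap;
import java.util.Map;

/**
 * Hypothetical stand-in for org.apache.tvm.Function, reduced to a name-keyed
 * registry so this sketch runs without the native TVM runtime.
 */
final class Function {
  // Assumption: a plain map replaces TVM's native global function registry.
  private static final Map<String, Function> REGISTRY = new HashMap<>();

  private final String name;

  private Function(String name) {
    this.name = name;
  }

  /** Registers a placeholder function under a global name (illustration only). */
  static void register(String name) {
    REGISTRY.put(name, new Function(name));
  }

  /** Mirrors the lookup used in the diff: returns null when the name is unknown. */
  static Function getFunction(String name) {
    return REGISTRY.get(name);
  }

  /**
   * Hypothetical helper suggested in the review: look up a global function or
   * fail with a descriptive error, so call sites no longer repeat the null check.
   */
  static Function getFunctionOrThrow(String name) {
    Function f = getFunction(name);
    if (f == null) {
      throw new RuntimeException("Cannot find global function " + name
          + ". Did you compile tvm_runtime with the correct version?");
    }
    return f;
  }

  String name() {
    return name;
  }
}
```

With such a helper, the start of `GraphRuntime.create` could shrink to a single line, `Function fcreate = Function.getFunctionOrThrow("tvm.graph_runtime.create");`. The error message would also be built in one place, which would avoid slips like the missing space between concatenated sentences in the removed code.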




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


