tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From GitBox <...@apache.org>
Subject [GitHub] [incubator-tvm] zhiics commented on a change in pull request #4964: [Torch] Add initial control flow support
Date Sun, 01 Mar 2020 01:56:45 GMT
zhiics commented on a change in pull request #4964: [Torch] Add initial control flow support

URL: https://github.com/apache/incubator-tvm/pull/4964#discussion_r386069549
 
 

 ##########
 File path: python/tvm/relay/frontend/pytorch.py
 ##########
 @@ -955,7 +1025,100 @@ def parse_params(graph, state_dict):
     return params, param_tensors
 
 
-def parse_operators(operators, outputs, output_index_map, ret_name):
def convert_block(block, outputs, output_index_map):
    """Convert a Torch "Block" (the body of a prim::If or prim::Loop).

    Returns the Relay expressions corresponding to the block's
    return values, as produced by convert_operators.
    """
    return_names = _get_input_names(block.returnNode())
    operator_nodes = _get_operator_nodes(block.nodes())
    return convert_operators(operator_nodes, outputs,
                             output_index_map, return_names)
+
+
def convert_if(if_node, outputs, output_index_map):
    """Translate a Torch prim::If node into a Relay If expression."""
    # The condition was produced by an earlier node; look it up by debug name.
    cond_name = if_node.inputsAt(0).debugName()
    cond = outputs[output_index_map[cond_name]]
    blocks = tuple(if_node.blocks())
    then_exprs = convert_block(blocks[0], outputs, output_index_map)
    else_exprs = convert_block(blocks[1], outputs, output_index_map)
    # Each branch of prim::If must yield exactly one value.
    assert len(then_exprs) == 1 and len(else_exprs) == 1
    return _expr.If(cond, then_exprs[0], else_exprs[0])
+
+
+def convert_loop(loop_node, outputs, output_index_map):
+    """ Translate Torch prim::Loop to Relay while_loop """
+    def get_input(index):
+        ivalue = loop_node.inputsAt(index)
+        inode = ivalue.node()
+        if inode.kind() == "prim::Constant":
+            return _expr.const(_get_constant(inode))
+        var_name = ivalue.debugName()
+        assert var_name in output_index_map
+        return _wrap_const(outputs[output_index_map[var_name]])
+
+    # Refer to the spec for prim::Loop below
+    # https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#loops
+    # The first input: %max_trip_count
+    # The second input: %initial_condition
+    # The rest of input: loop variables
+    max_loop_count = get_input(0)
+    init_cond = get_input(1)
+    num_loop_var = len(list(loop_node.inputs())) - 2
+    init_vals = [get_input(i + 2) for i in range(num_loop_var)]
+
+    # For loop (not while loop) has always %initial_condition being 1
+    is_for_loop = isinstance(init_cond, _expr.Constant)
 
 Review comment:
   Thanks. But technically, it is still possible for the `max_trip_count` of a `for` loop to be
`sys.maxsize`, right? I agree this should be very rare, but I am not sure this check is robust enough.
It looks like Torch uses `%i` to determine whether a loop is a `for` loop or a `while` loop: `%i` is
not used in a `while` loop.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Mime
View raw message