tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From GitBox <...@apache.org>
Subject [GitHub] [incubator-tvm] maheshambule commented on a change in pull request #5082: [Relay, Topi] [TF, MXNet] Unravel Index operator
Date Fri, 20 Mar 2020 20:29:35 GMT
maheshambule commented on a change in pull request #5082: [Relay, Topi] [TF, MXNet] Unravel Index operator
URL: https://github.com/apache/incubator-tvm/pull/5082#discussion_r395875063
 
 

 ##########
 File path: src/relay/op/tensor/transform.cc
 ##########
 @@ -2703,5 +2701,67 @@ RELAY_REGISTER_OP("one_hot")
 .set_attr<FTVMCompute>("FTVMCompute", OneHotCompute)
 .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
 
+/* relay.unravel_index */
+/*!
+ * \brief Type relation for unravel_index.
+ *
+ * types[0] is the flat-index tensor (integer dtype, scalar or 1-D),
+ * types[1] is the target shape (integer dtype, 1-D) and types[2] is the output.
+ * The output has shape (shape_len,) for scalar indices or
+ * (shape_len, num_indices) for 1-D indices, and the dtype of the indices.
+ *
+ * \return false while either input type is still incomplete; true once the
+ *         output type has been assigned.
+ */
+bool UnRavelIndexRel(const Array<Type>& types,
+                     int num_inputs,
+                     const Attrs& attrs,
+                     const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3);
+
+  const auto* indices = types[0].as<TensorTypeNode>();
+  if (indices == nullptr) {
+    CHECK(types[0].as<IncompleteTypeNode>())
+        << "unravel_index: expect input type to be TensorType but get " << types[0];
+    return false;
+  }
+  CHECK(indices->dtype.is_int()) << "indices of unravel_index must be tensor of integer";
+
+  const auto* shape = types[1].as<TensorTypeNode>();
+  if (shape == nullptr) {
+    CHECK(types[1].as<IncompleteTypeNode>())
+        << "unravel_index: expect input type to be TensorType but get " << types[1];
+    return false;
+  }
+  // Check the shape tensor's own dtype (not the indices' dtype again).
+  CHECK(shape->dtype.is_int()) << "shape of unravel_index must be tensor of integer";
+
+  const Array<IndexExpr>& indices_shape = indices->shape;
+  const Array<IndexExpr>& shape_shape = shape->shape;
+  // shape_shape[0] is read below, so a scalar shape would index out of bounds.
+  CHECK_EQ(shape_shape.size(), 1) << "unravel_index: shape must be a 1-D tensor";
+  // Only indices_shape[0] feeds the output type; higher-rank indices would
+  // otherwise silently lose their trailing dimensions.
+  CHECK_LE(indices_shape.size(), 1)
+      << "unravel_index: indices must be a scalar or 1-D tensor";
+
+  Array<IndexExpr> oshape;
+  oshape.push_back(shape_shape[0]);
+  if (indices_shape.size() != 0) {
+    oshape.push_back(indices_shape[0]);
+  }
+  reporter->Assign(types[2], TensorType(oshape, indices->dtype));
+  return true;
+}
+
+/*!
+ * \brief Lower unravel_index to the TOPI implementation.
+ *
+ * inputs[0] holds the flat indices and inputs[1] the target shape, in the
+ * argument order built by MakeUnRavelIndex. attrs and out_type are unused.
+ */
+Array<te::Tensor> UnRavelIndexCompute(const Attrs& attrs,
+                                      const Array<te::Tensor>& inputs,
+                                      const Type& out_type) {
+  te::Tensor flat_indices = inputs[0];
+  te::Tensor target_shape = inputs[1];
+  Array<te::Tensor> result;
+  result.push_back(topi::unravel_index(flat_indices, target_shape));
+  return result;
+}
+
+/*!
+ * \brief Build a call expression for the unravel_index operator.
+ * \param data The flat indices.
+ * \param shape The shape to unravel the indices into.
+ * \return A Call to "unravel_index" with empty Attrs and no type arguments.
+ */
+Expr MakeUnRavelIndex(Expr data, Expr shape) {
+  // The function-local static caches the registry lookup across calls.
+  static const Op& unravel_op = Op::Get("unravel_index");
+  Array<Expr> call_args{data, shape};
+  return CallNode::make(unravel_op, call_args, Attrs(), Array<Type>());
+}
+
+// Register MakeUnRavelIndex as the packed function "relay.op._make.unravel_index".
+TVM_REGISTER_GLOBAL("relay.op._make.unravel_index")
+    .set_body_typed(MakeUnRavelIndex);
+
+// Operator registration: two tensor inputs (flat indices, target shape),
+// typed by UnRavelIndexRel and lowered through UnRavelIndexCompute.
+RELAY_REGISTER_OP("unravel_index")
+    .describe(
+        R"code(Converts a flat index or array of flat indices into a tuple of coordinate arrays.
+
+Example::
+    -   unravel_index([22, 41, 37], (7, 6)) = [[3, 6, 6], [4, 5, 1]]
+)code" TVM_ADD_FILELINE)
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "Integer tensor of flat indices (scalar or 1-D).")
+    .add_argument("shape", "Tensor", "1-D integer tensor giving the shape to unravel into.")
+    .set_support_level(3)
+    .add_type_rel("UnRavelIndexRel", UnRavelIndexRel)
+    .set_attr<FTVMCompute>("FTVMCompute", UnRavelIndexCompute)
+    .set_attr<TOpPattern>("TOpPattern", kInjective);
+
 
 Review comment:
   Fixed

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Mime
View raw message