tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From mas...@apache.org
Subject [incubator-tvm] branch master updated: [Relay, Topi, TF Frontend] Isfinite operator (#4981)
Date Mon, 23 Mar 2020 11:17:23 GMT
This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 9037f4e  [Relay, Topi, TF Frontend] Isfinite operator (#4981)
9037f4e is described below

commit 9037f4ec988b6f9669b009b33a576500b1720c6d
Author: Mahesh Ambule <15611578+maheshambule@users.noreply.github.com>
AuthorDate: Mon Mar 23 16:47:12 2020 +0530

    [Relay, Topi, TF Frontend] Isfinite operator (#4981)
    
    * isfinite doc update
    
    * isfinit expr
    
    * isfinit expr
    
    * isfinite schedule reg
    
    * isfinite python binding
    
    * isfinite python binding
    
    * relay register isfinite
    
    * isfinite type relation
    
    * intrin isfinite
    
    * topi isfinite
    
    * testcase topi isfinite
    
    * tf frontend isfinite
    
    * tf frontend isfinite testcase
    
    * test case relay isfinite
    
    * small fixes
    
    * test forward tf isfinite
    
    * test cases injective for cuda
    
    * remove float16 test case
    
    * add support for isinf
    
    * remove unwanted import
    
    * fix conflict
---
 docs/api/python/topi.rst                         |  4 +++
 docs/frontend/tensorflow.rst                     |  2 ++
 include/tvm/tir/expr.h                           |  2 ++
 include/tvm/tir/op.h                             | 21 ++++++++++++++
 python/tvm/relay/frontend/tensorflow.py          |  2 ++
 python/tvm/relay/op/_tensor.py                   |  2 ++
 python/tvm/relay/op/tensor.py                    | 32 +++++++++++++++++++++
 python/tvm/te/__init__.py                        |  3 +-
 python/tvm/tir/__init__.py                       |  3 +-
 python/tvm/tir/op.py                             | 32 +++++++++++++++++++++
 src/relay/op/tensor/unary.cc                     | 18 ++++++++++++
 src/relay/op/type_relations.cc                   | 12 ++++++++
 src/relay/op/type_relations.h                    |  5 ++++
 src/target/intrin_rule.cc                        | 16 +++++++++++
 src/tir/ir/op.cc                                 | 36 ++++++++++++++++++++++++
 tests/python/frontend/tensorflow/test_forward.py | 34 +++++++++++++++++++++-
 tests/python/relay/test_op_level3.py             | 31 +++++++++++++++++++-
 topi/include/topi/elemwise.h                     |  2 ++
 topi/python/topi/math.py                         | 34 ++++++++++++++++++++++
 topi/tests/python/test_topi_math.py              | 32 +++++++++++++++++++++
 20 files changed, 319 insertions(+), 4 deletions(-)

diff --git a/docs/api/python/topi.rst b/docs/api/python/topi.rst
index 676dde9..e6a2c38 100644
--- a/docs/api/python/topi.rst
+++ b/docs/api/python/topi.rst
@@ -33,6 +33,8 @@ List of operators
    topi.round
    topi.abs
    topi.isnan
+   topi.isfinite
+   topi.isinf
    topi.exp
    topi.tanh
    topi.log
@@ -134,6 +136,8 @@ topi
 .. autofunction:: topi.round
 .. autofunction:: topi.abs
 .. autofunction:: topi.isnan
+.. autofunction:: topi.isfinite
+.. autofunction:: topi.isinf
 .. autofunction:: topi.exp
 .. autofunction:: topi.tanh
 .. autofunction:: topi.log
diff --git a/docs/frontend/tensorflow.rst b/docs/frontend/tensorflow.rst
index 80230c6..45db9e4 100644
--- a/docs/frontend/tensorflow.rst
+++ b/docs/frontend/tensorflow.rst
@@ -160,6 +160,8 @@ Supported Ops
 - Greater
 - GreaterEqual
 - Identity
+- IsFinite
+- IsInf
 - LeakyRelu
 - LeftShift
 - Less
diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h
index 9efc5d4..90fef87 100644
--- a/include/tvm/tir/expr.h
+++ b/include/tvm/tir/expr.h
@@ -832,6 +832,8 @@ class CallNode : public PrimExprNode {
   static constexpr const char* glsl_texture_store = "glsl_texture_store";
   static constexpr const char* prefetch = "prefetch";
   static constexpr const char* isnan = "isnan";
+  static constexpr const char* isfinite = "isfinite";
+  static constexpr const char* isinf = "isinf";
 
   /*! \brief Vectorizable intrinsic list. */
   static const char* vectorizable_intrinsics[];
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index afdc5fc..b54aa9a 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -84,6 +84,13 @@ TVM_DLL PrimExpr max_value(const DataType& dtype);
 TVM_DLL PrimExpr min_value(const DataType& dtype);
 
 /*!
+ * Get the value of infinity.
+ * \param dtype The data type.
+ * \return the infinity value in this format.
+ */
+TVM_DLL PrimExpr infinity(const DataType& dtype);
+
+/*!
  * \brief cast value to type.
  *
  * \param t the target type.
@@ -440,6 +447,20 @@ TVM_DLL PrimExpr abs(PrimExpr x);
 TVM_DLL PrimExpr isnan(PrimExpr x);
 
 /*!
+ * \brief Check if x is finite.
+ * \param x The input data
+ * \return The result expression.
+ */
+TVM_DLL PrimExpr isfinite(PrimExpr x);
+
+/*!
+ * \brief Check if x is infinite.
+ * \param x The input data
+ * \return The result expression.
+ */
+TVM_DLL PrimExpr isinf(PrimExpr x);
+
+/*!
  * \brief sum of of source expression over axis
  * \param source The source expression.
  * \param axis List of iteration variables that will be used for reduction.
diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 4221cac..9cdd68b 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -1667,6 +1667,8 @@ _convert_map = {
     'Greater'                           : _broadcast('greater'),
     'GreaterEqual'                      : _broadcast('greater_equal'),
     'Identity'                          : _identity(),
+    'IsFinite'                          : AttrCvt('isfinite'),
+    'IsInf'                             : AttrCvt('isinf'),
     'LeakyRelu'                         : AttrCvt('leaky_relu'),
     'LeftShift'                         : AttrCvt('left_shift'),
     'Less'                              : _broadcast('less'),
diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py
index 4480849..4b5eada 100644
--- a/python/tvm/relay/op/_tensor.py
+++ b/python/tvm/relay/op/_tensor.py
@@ -66,6 +66,8 @@ register_broadcast_schedule("less")
 register_broadcast_schedule("less_equal")
 register_broadcast_schedule("greater")
 register_broadcast_schedule("greater_equal")
+register_broadcast_schedule("isfinite")
+register_broadcast_schedule("isinf")
 register_injective_schedule("maximum")
 register_injective_schedule("minimum")
 register_injective_schedule("right_shift")
diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py
index 7796918..1f481ee 100644
--- a/python/tvm/relay/op/tensor.py
+++ b/python/tvm/relay/op/tensor.py
@@ -1008,3 +1008,35 @@ def ndarray_size(data, dtype="int32"):
         The number of elements of input tensor.
     """
     return _make.ndarray_size(data, dtype)
+
+
+def isfinite(data):
+    """Compute element-wise finiteness of data.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+    return _make.isfinite(data)
+
+
+def isinf(data):
+    """Compute element-wise infiniteness of data.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+    return _make.isinf(data)
diff --git a/python/tvm/te/__init__.py b/python/tvm/te/__init__.py
index 83cd189..e88c17a 100644
--- a/python/tvm/te/__init__.py
+++ b/python/tvm/te/__init__.py
@@ -20,7 +20,8 @@
 # expose all operators in tvm tir.op
 from tvm.tir import any, all, min_value, max_value, trace
 from tvm.tir import exp, erf, tanh, sigmoid, log, tan, cos, sin, atan, sqrt, rsqrt, floor, ceil
-from tvm.tir import trunc, abs, round, nearbyint, isnan, power, popcount, fmod, if_then_else
+from tvm.tir import trunc, abs, round, nearbyint, power, popcount, fmod, if_then_else
+from tvm.tir import isnan, isfinite, isinf
 from tvm.tir import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
 from tvm.tir import comm_reducer, min, max, sum
 
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index b8a56f8..d4d389a 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -38,7 +38,8 @@ from .op import call_llvm_intrin, all, any, min_value, max_value, trace
 from .op import exp, exp2, exp10, log, log2, log10
 from .op import cos, sin, cosh, sinh, tan, tanh, atan
 from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil
-from .op import trunc, abs, round, nearbyint, isnan, power, popcount, fmod, if_then_else
+from .op import trunc, abs, round, nearbyint, power, popcount, fmod, if_then_else
+from .op import isnan, isfinite, isinf
 from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
 from .op import comm_reducer, min, max, sum
 
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index d82724f..4b703f3 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -706,6 +706,38 @@ def isnan(x):
     return _ffi_api.isnan(x)
 
 
+def isfinite(x):
+    """Check if input value is finite.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return _ffi_api.isfinite(x)
+
+
+def isinf(x):
+    """Check if input value is infinite.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return _ffi_api.isinf(x)
+
+
 def power(x, y):
     """x power y
 
diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc
index b9d2473..beb6cd7 100644
--- a/src/relay/op/tensor/unary.cc
+++ b/src/relay/op/tensor/unary.cc
@@ -415,5 +415,23 @@ ElemwiseArbitraryLayout)
 .set_support_level(10)
 .set_attr<FTVMCompute>("FTVMCompute", NdarraySizeCompute);
 
+RELAY_REGISTER_UNARY_OP("isfinite")
+.describe(R"code(Returns the finiteness of input, computed element-wise.
+.. math::
+   isfinite(x)
+)code" TVM_ADD_FILELINE)
+.set_support_level(3)
+.add_type_rel("IdentityCompRel", IdentityCompRel)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::isfinite));
+
+RELAY_REGISTER_UNARY_OP("isinf")
+.describe(R"code(Returns the infiniteness of input, computed element-wise.
+.. math::
+   isfinite(x)
+)code" TVM_ADD_FILELINE)
+.set_support_level(3)
+.add_type_rel("IdentityCompRel", IdentityCompRel)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::isinf));
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc
index 3c7d148..f9653e2 100644
--- a/src/relay/op/type_relations.cc
+++ b/src/relay/op/type_relations.cc
@@ -136,6 +136,18 @@ bool BroadcastCompRel(const Array<Type>& types,
   return false;
 }
 
+bool IdentityCompRel(const Array<Type>& types,
+                     int num_inputs,
+                     const Attrs& attrs,
+                     const TypeReporter& reporter) {
+  if (auto* t0 = types[0].as<TensorTypeNode>()) {
+    Type out_type = TensorType(GetRef<TensorType>(t0)->shape, DataType::Bool());
+    reporter->Assign(types[1], out_type);
+    return true;
+  }
+  return false;
+}
+
 Array<IndexExpr> RankShape(const Array<IndexExpr>& shape) {
   if (shape.size() == 0) {
     return {};
diff --git a/src/relay/op/type_relations.h b/src/relay/op/type_relations.h
index 80e555b..48a545b 100644
--- a/src/relay/op/type_relations.h
+++ b/src/relay/op/type_relations.h
@@ -79,6 +79,11 @@ bool BroadcastCompRel(const Array<Type>& types,
                       const Attrs& attrs,
                       const TypeReporter& reporter);
 
+bool IdentityCompRel(const Array<Type>& types,
+                 int num_inputs,
+                 const Attrs& attrs,
+                 const TypeReporter& reporter);
+
 Array<IndexExpr> RankShape(const Array<IndexExpr>& shape);
 
 }  // namespace relay
diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc
index ade3bbd..626498b 100644
--- a/src/target/intrin_rule.cc
+++ b/src/target/intrin_rule.cc
@@ -78,6 +78,22 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sigmoid")
     *rv = one / (one + exp(-call->args[0]));
   });
 
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.isfinite")
+.set_body([](const TVMArgs& args, TVMRetValue* rv){
+    PrimExpr e = args[0];
+    const CallNode* call = e.as<CallNode>();
+    CHECK(call != nullptr);
+    *rv = isfinite(call->args[0]);
+  });
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.isinf")
+.set_body([](const TVMArgs& args, TVMRetValue* rv){
+    PrimExpr e = args[0];
+    const CallNode* call = e.as<CallNode>();
+    CHECK(call != nullptr);
+    *rv = isinf(call->args[0]);
+  });
+
 }  // namespace intrin
 }  // namespace codegen
 }  // namespace tvm
diff --git a/src/tir/ir/op.cc b/src/tir/ir/op.cc
index c11fb2a..cf1c24c 100644
--- a/src/tir/ir/op.cc
+++ b/src/tir/ir/op.cc
@@ -180,6 +180,21 @@ PrimExpr min_value(const DataType& dtype) {
   return PrimExpr();
 }
 
+// infinity
+PrimExpr infinity(const DataType& dtype) {
+  using namespace tir;
+  CHECK_EQ(dtype.lanes(), 1);
+  if (dtype.is_float()) {
+    if (dtype.bits() == 64) {
+      return FloatImm(dtype, std::numeric_limits<double>::infinity());
+    } else if (dtype.bits() == 32 || dtype.bits() == 16) {
+      return FloatImm(dtype, std::numeric_limits<float>::infinity());
+    }
+  }
+  LOG(FATAL) << "Cannot decide infinity for type " << dtype;
+  return PrimExpr();
+}
+
 namespace tir {
 template<typename ValueType>
 inline bool ConstPowerHelper(ValueType val, int *shift) {
@@ -575,6 +590,21 @@ PrimExpr isnan(PrimExpr x) {
   }
 }
 
+PrimExpr isinf(PrimExpr x) {
+  DataType t = DataType::Bool(x.dtype().lanes());
+  if (x.dtype().is_int() || x.dtype().is_uint()) {
+    return make_const(t, false);
+  } else if (x.dtype().is_float()) {
+    PrimExpr infX = infinity(x.dtype());
+    return abs(x) == infX && !isnan(x);
+  } else {
+    LOG(FATAL) << "Data type " << x.dtype() << " not supported for finiteness ops. Skipping it...";
+    return x;
+  }
+}
+
+PrimExpr isfinite(PrimExpr x) { return !isinf(x) && !isnan(x); }
+
 PrimExpr sum(PrimExpr source, Array<IterVar> rdom) {
   Var x("x", source.dtype()), y("y", source.dtype());
   PrimExpr result = tir::AddNode::make(x, y);
@@ -721,6 +751,12 @@ TVM_REGISTER_GLOBAL("tir.abs")
 TVM_REGISTER_GLOBAL("tir.isnan")
 .set_body_typed(tvm::isnan);
 
+TVM_REGISTER_GLOBAL("tir.isfinite")
+.set_body_typed(tvm::isfinite);
+
+TVM_REGISTER_GLOBAL("tir.isinf")
+.set_body_typed(tvm::isinf);
+
 TVM_REGISTER_GLOBAL("tir.floor")
 .set_body_typed(tvm::floor);
 
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 3c51977..78d504e 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -3152,7 +3152,37 @@ def test_forward_dilation():
     _test_dilation2d([1, 3, 3, 1], [2, 2, 1], [1, 1, 1, 1], [1, 2, 2, 1], "SAME")
     _test_dilation2d([1, 3, 3, 1], [2, 2, 1], [1, 1, 1, 1], [1, 1, 2, 1], "VALID")
 
-# #######################################################################
+
+#######################################################################
+# infinity ops
+# ------------
+def _verify_infiniteness_ops(tf_op, name):
+    """test operator infinity ops"""
+
+    # Only float types are allowed in Tensorflow for isfinite and isinf
+    # float16 is failing on cuda
+    tf_dtypes = ["float32", "float64"]
+    for tf_dtype in tf_dtypes:
+        shape = (8, 8)
+        data = np.random.uniform(size=shape).astype(tf_dtype)
+        data.ravel()[np.random.choice(data.size, int(data.size * 0.5), replace=False)] = np.infty
+        data.ravel()[np.random.choice(data.size, int(data.size * 0.5), replace=False)] = np.nan
+
+        tf.reset_default_graph()
+        in_data = tf.placeholder(tf_dtype, shape, name="in_data")
+        tf_op(in_data, name=name)
+        compare_tf_with_tvm([data], ['in_data:0'], '{}:0'.format(name))
+
+
+def test_forward_isinf():
+    _verify_infiniteness_ops(tf.is_inf, "isinf")
+
+
+def test_forward_isfinite():
+    _verify_infiniteness_ops(tf.is_finite, "isfinite")
+
+
+#######################################################################
 # Main
 # ----
 if __name__ == '__main__':
@@ -3224,6 +3254,8 @@ if __name__ == '__main__':
     test_forward_squared_difference()
     test_forward_add_n()
     test_forward_floormod()
+    test_forward_isfinite()
+    test_forward_isinf()
     test_forward_unravel_index()
 
     # Reductions
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index fffb1de..4deed42 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -684,6 +684,33 @@ def test_gather_nd():
     verify_gather_nd((3, 2), (2, 2, 3), [[[0, 1, 2], [2, 0, 1]], [[0, 0, 0], [1, 1, 1]]])
 
 
+def _verify_infiniteness_ops(relay_op, ref_op):
+    for dtype in ['float32', 'float16', 'float16', 'int32', 'int16']:
+        shape = (2, 8, 8)
+        x = relay.var("x", relay.TensorType(shape, dtype))
+        y = relay_op(x)
+        yy = run_infer_type(y)
+        assert yy.checked_type == relay.TensorType(shape, "bool")
+
+        data = np.random.uniform(size=shape).astype(dtype)
+        if dtype.startswith('float'):
+            data.ravel()[np.random.choice(data.size, int(data.size * 0.5), replace=False)] = np.infty
+            data.ravel()[np.random.choice(data.size, int(data.size * 0.5), replace=False)] = np.nan
+
+        intrp = create_executor()
+        op_res = intrp.evaluate(y, {x: data})
+        ref_res = ref_op(data)
+        np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
+
+
+def test_isfinite():
+    _verify_infiniteness_ops(relay.isfinite, np.isfinite)
+
+
+def test_isinf():
+    _verify_infiniteness_ops(relay.isinf, np.isinf)
+
+    
 def test_unravel_index():
     def verify_unravel_index(indices, shape, dtype):
         x_data = np.array(indices).astype(dtype)
@@ -751,4 +778,6 @@ if __name__ == "__main__":
     test_tile()
     test_repeat()
     test_gather_nd()
-    test_unravel_index()
+    test_isfinite()
+    test_isinf()
+    test_unravel_index()
\ No newline at end of file
diff --git a/topi/include/topi/elemwise.h b/topi/include/topi/elemwise.h
index 26107ea..88d5732 100644
--- a/topi/include/topi/elemwise.h
+++ b/topi/include/topi/elemwise.h
@@ -60,6 +60,8 @@ TOPI_DECLARE_UNARY_OP(sin);
 TOPI_DECLARE_UNARY_OP(atan);
 TOPI_DECLARE_UNARY_OP(isnan);
 TOPI_DECLARE_UNARY_OP(tanh);
+TOPI_DECLARE_UNARY_OP(isfinite);
+TOPI_DECLARE_UNARY_OP(isinf);
 
 /*
  * \brief Fast_tanh_float implementation from Eigen
diff --git a/topi/python/topi/math.py b/topi/python/topi/math.py
index 3eda88a..6f71ae9 100644
--- a/topi/python/topi/math.py
+++ b/topi/python/topi/math.py
@@ -278,6 +278,40 @@ def isnan(x):
 
 
 @tvm.te.tag_scope(tag=tag.ELEMWISE)
+def isfinite(x):
+    """Check if value of x is finite, element-wise.
+
+    Parameters
+    ----------
+    x : tvm.Tensor
+        Input argument.
+
+    Returns
+    -------
+    y : tvm.Tensor
+        The result.
+    """
+    return te.compute(x.shape, lambda *i: te.isfinite(x(*i)))
+
+
+@tvm.te.tag_scope(tag=tag.ELEMWISE)
+def isinf(x):
+    """Check if value of x is infinite, element-wise.
+
+    Parameters
+    ----------
+    x : tvm.Tensor
+        Input argument.
+
+    Returns
+    -------
+    y : tvm.Tensor
+        The result.
+    """
+    return te.compute(x.shape, lambda *i: te.isinf(x(*i)))
+
+
+@tvm.te.tag_scope(tag=tag.ELEMWISE)
 def round(x):
     """Round elements of x to nearest integer.
 
diff --git a/topi/tests/python/test_topi_math.py b/topi/tests/python/test_topi_math.py
index 3e58518..a8a56ef 100644
--- a/topi/tests/python/test_topi_math.py
+++ b/topi/tests/python/test_topi_math.py
@@ -113,6 +113,36 @@ def test_ewise():
         for target in get_all_backend():
             check_device(target)
 
+    def test_infiniteness_ops(topi_op, ref_op, name):
+        for dtype in ['float32', 'float64', 'int32', 'int16']:
+            m = te.var("m")
+            l = te.var("l")
+            A = te.placeholder((m, l), dtype=dtype, name="A")
+            B = topi_op(A)
+            assert tuple(B.shape) == tuple(A.shape)
+
+            a_np = np.random.uniform(size=(8, 8)).astype(A.dtype) * 10
+            if dtype.startswith('float'):
+                a_np.ravel()[np.random.choice(a_np.size, int(a_np.size * 0.5), replace=False)] = np.infty
+                a_np.ravel()[np.random.choice(a_np.size, int(a_np.size * 0.5), replace=False)] = np.nan
+            b_np = ref_op(a_np)
+
+            def check_device(device):
+                ctx = tvm.context(device, 0)
+                if not ctx.exist:
+                    print("Skip because %s is not enabled" % device)
+                    return
+                with tvm.target.create(device):
+                    s = topi.testing.get_injective_schedule(device)(B)
+                foo = tvm.build(s, [A, B], device, name=name)
+                a = tvm.nd.array(a_np, ctx)
+                b = tvm.nd.array(np.zeros_like(b_np), ctx)
+                foo(a, b)
+                tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5)
+
+            for target in get_all_backend():
+                check_device(target)
+
     test_apply(topi.floor, "floor", np.floor, -100, 100)
     test_apply(topi.ceil, "ceil", np.ceil, -100, 100)
     test_apply(topi.sign, "sign", np.sign, -100, 100, skip_name_check=True)
@@ -132,6 +162,8 @@ def test_ewise():
     test_apply(topi.sin, "sin", np.sin, -2.0*np.pi, 2.0*np.pi)
     test_apply(topi.erf, "erf", scipy.special.erf, -.1, .1, dtype="float32")
     test_isnan(-100, 100)
+    test_infiniteness_ops(topi.isfinite, np.isfinite, 'isifinite')
+    test_infiniteness_ops(topi.isinf, np.isinf, 'isinf')
 
 
 def test_cast():


Mime
View raw message