tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From kevinthe...@apache.org
Subject [incubator-tvm] branch master updated: [Relay] symbolic max_output_size (#5844)
Date Mon, 29 Jun 2020 06:18:49 GMT
This is an automated email from the ASF dual-hosted git repository.

kevinthesun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 78d7992  [Relay] symbolic max_output_size  (#5844)
78d7992 is described below

commit 78d79923756ea9ed4545d2faef7d514a300d3452
Author: Yong Wu <ywu118@alumni.jh.edu>
AuthorDate: Mon Jun 29 14:18:38 2020 +0800

    [Relay] symbolic max_output_size  (#5844)
    
    * symbolic max_output_size
    
    * pylint
    
    * fix ci
---
 include/tvm/relay/attrs/vision.h                 |  8 +--
 python/tvm/relay/frontend/tensorflow.py          |  3 +-
 python/tvm/relay/op/strategy/generic.py          |  4 +-
 python/tvm/relay/op/vision/nms.py                |  8 ++-
 src/relay/op/vision/nms.cc                       | 14 ++---
 tests/python/frontend/tensorflow/test_forward.py | 12 ++--
 tests/python/relay/test_op_level5.py             | 40 ++++++------
 topi/python/topi/cuda/nms.py                     |  2 +-
 topi/python/topi/vision/nms.py                   | 78 +++++++++++++-----------
 topi/tests/python/test_topi_vision.py            | 18 +++---
 10 files changed, 101 insertions(+), 86 deletions(-)

diff --git a/include/tvm/relay/attrs/vision.h b/include/tvm/relay/attrs/vision.h
index 550e24b..2b905f5 100644
--- a/include/tvm/relay/attrs/vision.h
+++ b/include/tvm/relay/attrs/vision.h
@@ -88,7 +88,7 @@ struct GetValidCountsAttrs : public tvm::AttrsNode<GetValidCountsAttrs> {
 
 /*! \brief Attributes used in non_maximum_suppression operator */
 struct NonMaximumSuppressionAttrs : public tvm::AttrsNode<NonMaximumSuppressionAttrs> {
-  int max_output_size;
+  Optional<Integer> max_output_size;
   double iou_threshold;
   bool force_suppress;
   int top_k;
@@ -99,11 +99,7 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode<NonMaximumSuppressionA
   bool invalid_to_bottom;
 
   TVM_DECLARE_ATTRS(NonMaximumSuppressionAttrs, "relay.attrs.NonMaximumSuppressionAttrs") {
-    TVM_ATTR_FIELD(max_output_size)
-        .set_default(-1)
-        .describe(
-            "Max number of output valid boxes for each instance."
-            "By default all valid boxes are returned.");
+    TVM_ATTR_FIELD(max_output_size).describe("Max number of output valid boxes for each instance.");
     TVM_ATTR_FIELD(iou_threshold)
         .set_default(0.5)
         .describe("Non-maximum suppression iou threshold.");
diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index b5f34e4..a1a4072 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -625,7 +625,6 @@ def _conv3d(opname):
 def _nms():
     def _impl(inputs, attr, params, mod):
         # Get parameter values
-        # TODO(yongwww) change nms in relay to support symbolic max_output_size
         try:
             max_output_size = int(np.atleast_1d(inputs[2].data.asnumpy()
                                                 .astype("int64"))[0])
@@ -634,7 +633,7 @@ def _nms():
                 max_output_size = _infer_value(inputs[2], params,
                                                mod).asnumpy().astype("int64").tolist()[0]
             except Exception:
-                max_output_size = -1
+                max_output_size = inputs[2]
         iou_threshold = np.atleast_1d(inputs[3].data.asnumpy())[0]
         # score_threshold was introduced from V3
         score_threshold = np.atleast_1d(inputs[4].data.asnumpy())[0] if len(inputs) > 4 else 0.0
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index a0dd6bf..e9feee6 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -747,8 +747,10 @@ def get_valid_counts_strategy(attrs, inputs, out_type, target):
 def wrap_compute_nms(topi_compute):
     """wrap nms topi compute"""
     def _compute_nms(attrs, inputs, out_type):
+        max_output_size = inputs[3]
+        if attrs.max_output_size is not None:
+            max_output_size = attrs.max_output_size
         return_indices = bool(get_const_int(attrs.return_indices))
-        max_output_size = get_const_int(attrs.max_output_size)
         iou_threshold = get_const_float(attrs.iou_threshold)
         force_suppress = bool(get_const_int(attrs.force_suppress))
         top_k = get_const_int(attrs.top_k)
diff --git a/python/tvm/relay/op/vision/nms.py b/python/tvm/relay/op/vision/nms.py
index b60b49a..60ff7a5 100644
--- a/python/tvm/relay/op/vision/nms.py
+++ b/python/tvm/relay/op/vision/nms.py
@@ -91,9 +91,9 @@ def non_max_suppression(data,
         second dimension are like the output of arange(num_anchors)
         if get_valid_counts is not used before non_max_suppression.
 
-    max_output_size : int, optional
+    max_output_size : int or relay.Expr, optional
         Max number of output valid boxes for each instance.
-        By default all valid boxes are returned.
+        Return all valid boxes if the value of max_output_size is less than 0.
 
     iou_threshold : float, optional
         Non-maximum suppression threshold.
@@ -124,9 +124,11 @@ def non_max_suppression(data,
     out : relay.Expr or relay.Tuple
         return relay.Expr if return_indices is disabled, a 3-D tensor
         with shape [batch_size, num_anchors, 6] or [batch_size, num_anchors, 5].
-        if return_indices is True, return relay.Tuple of two 2-D tensors, with
+        If return_indices is True, return relay.Tuple of two 2-D tensors, with
         shape [batch_size, num_anchors] and [batch_size, num_valid_anchors] respectively.
     """
+    if isinstance(max_output_size, int):
+        max_output_size = expr.const(max_output_size, "int32")
     out = _make.non_max_suppression(data,
                                     valid_count,
                                     indices,
diff --git a/src/relay/op/vision/nms.cc b/src/relay/op/vision/nms.cc
index 7486db7..f9cdaf6 100644
--- a/src/relay/op/vision/nms.cc
+++ b/src/relay/op/vision/nms.cc
@@ -73,7 +73,7 @@ TVM_REGISTER_NODE_TYPE(NonMaximumSuppressionAttrs);
 
 bool NMSRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
             const TypeReporter& reporter) {
-  CHECK_EQ(types.size(), 4);
+  CHECK_EQ(types.size(), 5);
   const auto* data = types[0].as<TensorTypeNode>();
   const auto* valid_count = types[1].as<TensorTypeNode>();
   const NonMaximumSuppressionAttrs* param = attrs.as<NonMaximumSuppressionAttrs>();
@@ -90,18 +90,17 @@ bool NMSRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     fields.push_back(TensorType(oshape, DataType::Int(32)));
     std::vector<IndexExpr> countshape({dshape[0], 1});
     fields.push_back(TensorType(countshape, DataType::Int(32)));
-    reporter->Assign(types[3], TupleType(Array<Type>(fields)));
+    reporter->Assign(types[4], TupleType(Array<Type>(fields)));
   } else {
-    reporter->Assign(types[3], TensorType(dshape, data->dtype));
+    reporter->Assign(types[4], TensorType(dshape, data->dtype));
   }
   return true;
 }
 
-Expr MakeNMS(Expr data, Expr valid_count, Expr indices, int max_output_size, double iou_threshold,
+Expr MakeNMS(Expr data, Expr valid_count, Expr indices, Expr max_output_size, double iou_threshold,
              bool force_suppress, int top_k, int coord_start, int score_index, int id_index,
              bool return_indices, bool invalid_to_bottom) {
   auto attrs = make_object<NonMaximumSuppressionAttrs>();
-  attrs->max_output_size = max_output_size;
   attrs->iou_threshold = iou_threshold;
   attrs->force_suppress = force_suppress;
   attrs->top_k = top_k;
@@ -111,7 +110,7 @@ Expr MakeNMS(Expr data, Expr valid_count, Expr indices, int max_output_size, dou
   attrs->return_indices = return_indices;
   attrs->invalid_to_bottom = invalid_to_bottom;
   static const Op& op = Op::Get("vision.non_max_suppression");
-  return Call(op, {data, valid_count, indices}, Attrs(attrs), {});
+  return Call(op, {data, valid_count, indices, max_output_size}, Attrs(attrs), {});
 }
 
 TVM_REGISTER_GLOBAL("relay.op.vision._make.non_max_suppression").set_body_typed(MakeNMS);
@@ -122,10 +121,11 @@ be in the format of [class_id, score, left, top, right, bottom]
 or [score, left, top, right, bottom]. Set id_index to be -1 to
 ignore class_id axis.
 )doc" TVM_ADD_FILELINE)
-    .set_num_inputs(3)
+    .set_num_inputs(4)
     .add_argument("data", "Tensor", "Input data.")
     .add_argument("valid_count", "Tensor", "Number of valid anchor boxes.")
     .add_argument("indices", "Tensor", "Corresponding indices in original input tensor.")
+    .add_argument("max_output_size", "Tensor", "Max number of output valid boxes.")
     .set_support_level(5)
     .add_type_rel("NMS", NMSRel);
 
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 1a0baf8..182c2d7 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -2020,15 +2020,16 @@ def test_forward_crop_and_resize():
 def _test_forward_nms_v3(bx_shape, score_shape, iou_threshold, score_threshold, out_size, dtype="float32"):
     boxes = np.random.uniform(0, 10, size=bx_shape).astype(dtype)
     scores = np.random.uniform(size=score_shape).astype(dtype)
+    max_output_size = np.int32(out_size)
     tf.reset_default_graph()
     in_data_1 = tf.placeholder(dtype, boxes.shape, name="in_data_1")
     in_data_2 = tf.placeholder(dtype, scores.shape, name="in_data_2")
-    tf.image.non_max_suppression(boxes=in_data_1, scores=in_data_2,
-                                 max_output_size=out_size, iou_threshold=iou_threshold,
-                                 score_threshold=score_threshold, name="nms")
-    compare_tf_with_tvm([boxes, scores], ['in_data_1:0', 'in_data_2:0'],
+    in_data_3 = tf.placeholder(tf.int32, name="in_data_3")
+    tf.image.non_max_suppression(boxes=in_data_1, scores=in_data_2, max_output_size=in_data_3,
+                                 iou_threshold=iou_threshold, score_threshold=score_threshold, name="nms")
+    compare_tf_with_tvm([boxes, scores, max_output_size], ['in_data_1:0', 'in_data_2:0', 'in_data_3:0'],
                         'nms/NonMaxSuppressionV3:0', mode='vm')
-    compare_tf_with_tvm([boxes, scores], ['in_data_1:0', 'in_data_2:0'],
+    compare_tf_with_tvm([boxes, scores, max_output_size], ['in_data_1:0', 'in_data_2:0', 'in_data_3:0'],
                         'nms/NonMaxSuppressionV3:0', mode='debug')
 
 def test_forward_nms_v3():
@@ -2036,6 +2037,7 @@ def test_forward_nms_v3():
     _test_forward_nms_v3((5, 4), (5,), 0.7, 0.5, 5)
     _test_forward_nms_v3((20, 4), (20,), 0.5, 0.6, 10)
     _test_forward_nms_v3((1000, 4), (1000,), 0.3, 0.7, 1000)
+    _test_forward_nms_v3((2000, 4), (2000,), 0.4, 0.6, 7)
 
 
 #######################################################################
diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py
index 14d43c0..265db43 100644
--- a/tests/python/relay/test_op_level5.py
+++ b/tests/python/relay/test_op_level5.py
@@ -283,16 +283,17 @@ def test_get_valid_counts():
 
 
 def test_non_max_suppression():
-    def verify_nms(x0_data, x1_data, x2_data, dshape, ref_res, ref_indices_res,
-                   iou_threshold=0.5, force_suppress=False, top_k=-1,
-                   check_type_only=False):
+    def verify_nms(x0_data, x1_data, x2_data, x3_data, dshape, ref_res,
+                   ref_indices_res, iou_threshold=0.5, force_suppress=False,
+                   top_k=-1, check_type_only=False):
         x0 = relay.var("x0", relay.ty.TensorType(dshape, "float32"))
         x1 = relay.var("x1", relay.ty.TensorType((dshape[0],), "int32"))
         x2 = relay.var("x2", relay.ty.TensorType((dshape[0], dshape[1]), "int32"))
-        z = relay.vision.non_max_suppression(x0, x1, x2, max_output_size=-1, \
+        x3 = relay.var("x3", relay.ty.TensorType((), "int32"))
+        z = relay.vision.non_max_suppression(x0, x1, x2, x3, \
             iou_threshold=iou_threshold, force_suppress=force_suppress, \
             top_k=top_k, return_indices=False)
-        z_indices = relay.vision.non_max_suppression(x0, x1, x2, max_output_size=-1, \
+        z_indices = relay.vision.non_max_suppression(x0, x1, x2, x3, \
                     iou_threshold=iou_threshold, force_suppress=force_suppress, \
                     top_k=top_k, return_indices=True)
         if isinstance(z_indices, relay.expr.TupleWrapper):
@@ -309,30 +310,30 @@ def test_non_max_suppression():
         if check_type_only:
             return
 
-        func = relay.Function([x0, x1, x2], z)
+        func = relay.Function([x0, x1, x2, x3], z)
         func = run_infer_type(func)
-        func_indices = relay.Function([x0, x1, x2], z_indices)
+        func_indices = relay.Function([x0, x1, x2, x3], z_indices)
         func_indices = run_infer_type(func_indices)
         for target, ctx in ctx_list():
             intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
-            op_res1 = intrp1.evaluate(func)(x0_data, x1_data, x2_data)
+            op_res1 = intrp1.evaluate(func)(x0_data, x1_data, x2_data, x3_data)
             tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5)
             intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
-            op_res2 = intrp2.evaluate(func)(x0_data, x1_data, x2_data)
+            op_res2 = intrp2.evaluate(func)(x0_data, x1_data, x2_data, x3_data)
             tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)
             if target == 'cuda':
                 return
-            op_indices_res1 = intrp1.evaluate(func_indices)(x0_data, x1_data, x2_data)
+            op_indices_res1 = intrp1.evaluate(func_indices)(x0_data, x1_data, x2_data, x3_data)
             tvm.testing.assert_allclose(op_indices_res1[0].asnumpy(), ref_indices_res, rtol=1e-5)
-            op_indices_res2 = intrp2.evaluate(func_indices)(x0_data, x1_data, x2_data)
+            op_indices_res2 = intrp2.evaluate(func_indices)(x0_data, x1_data, x2_data, x3_data)
             tvm.testing.assert_allclose(op_indices_res2[0].asnumpy(), ref_indices_res, rtol=1e-5)
 
     np_data = np.array([[[0, 0.8, 1, 20, 25, 45], [1, 0.7, 30, 60, 50, 80],
                          [0, 0.4, 4, 21, 19, 40], [2, 0.9, 35, 61, 52, 79],
                          [1, 0.5, 100, 60, 70, 110]]]).astype("float32")
     np_valid_count = np.array([4]).astype("int32")
-
     np_indices = np.array([[0, 1, 3, 4, -1]]).astype("int32")
+    np_max_output_size = -1
 
     np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45],
                            [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1],
@@ -341,22 +342,23 @@ def test_non_max_suppression():
     num_anchors = 5
 
     dshape = (te.size_var("n"), num_anchors, 6)
-    verify_nms(np_data, np_valid_count, np_indices, dshape, np_result, np_indices_result,
+    verify_nms(np_data, np_valid_count, np_indices, np_max_output_size, dshape, np_result, np_indices_result,
                force_suppress=True, top_k=2, check_type_only=True)
     dshape = (1, num_anchors, 6)
+    verify_nms(np_data, np_valid_count, np_indices, np_max_output_size, dshape, np_result, np_indices_result,
np_indices_result,
                force_suppress=True, top_k=2, check_type_only=False)
 
     np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45],
-                           [1, 0.7, 30, 60, 50, 80], [-1, -1, -1, -1, -1, -1],
+                           [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1],
                            [-1, -1, -1, -1, -1, -1]]])
-    np_indices_result = np.array([[4, 0, 1, -1, -1]])
+    np_indices_result = np.array([[4, 0, -1, -1, -1]])
+    np_max_output_size = 2
     dshape = (te.size_var("n"), num_anchors, 6)
-    verify_nms(np_data, np_valid_count, np_indices, dshape, np_result,
+    verify_nms(np_data, np_valid_count, np_indices, np_max_output_size, dshape, np_result,
                np_indices_result, check_type_only=True)
     dshape = (1, num_anchors, 6)
-    verify_nms(np_data, np_valid_count, np_indices, dshape, np_result,
-               np_indices_result, top_k=3)
+    verify_nms(np_data, np_valid_count, np_indices, np_max_output_size, dshape, np_result,
+               np_indices_result, top_k=2)
 
 
 def test_multibox_transform_loc():
diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py
index 9e3200a..9f46b95 100644
--- a/topi/python/topi/cuda/nms.py
+++ b/topi/python/topi/cuda/nms.py
@@ -458,7 +458,7 @@ def non_max_suppression(data, valid_count, indices, max_output_size=-1,
             in_buffers=[data_buf, sort_tensor_buf, valid_count_buf],
             name="nms",
             tag="nms")
-
+    # TODO(yongwww): Update cuda nms to be consistent with cpu version
     if return_indices:
         return box_indices
 
diff --git a/topi/python/topi/vision/nms.py b/topi/python/topi/vision/nms.py
index 269c876..1ee9e83 100644
--- a/topi/python/topi/vision/nms.py
+++ b/topi/python/topi/vision/nms.py
@@ -257,9 +257,12 @@ def hybrid_nms(data, sorted_index, valid_count, indices, batch_size, num_anchors
         Batch size. We need to pass it in since hybrid script doesn't support
         binding variable to symbolic dim.
 
-    max_output_size : tvm.tir.const
+    num_anchors: tvm.tir.IntImm or tvm.tir.Var
+        The number of anchors.
+
+    max_output_size : tvm.te.Tensor
         Max number of output valid boxes for each instance.
-        By default all valid boxes are returned.
+        Return all valid boxes if max_output_size < 0.
 
     iou_threshold : tvm.tir.const
         Overlapping(IoU) threshold to suppress object with smaller score.
@@ -300,7 +303,7 @@ def hybrid_nms(data, sorted_index, valid_count, indices, batch_size, num_anchors
 
     box_data_length = data.shape[2]
 
-    # box_indices is the expected value, similar to TF & ONNX
+    # box_indices is the expected indices of boxes
     box_indices = output_tensor((batch_size, num_anchors), sorted_index.dtype)
     output = output_tensor((batch_size,
                             num_anchors,
@@ -326,13 +329,33 @@ def hybrid_nms(data, sorted_index, valid_count, indices, batch_size, num_anchors
             # Apply nms
             box_start_idx = coord_start
             batch_idx = i
+            num_valid_boxes = 0
 
             for j in range(valid_count[i]):
-                if output[i, j, score_index] > 0 and (id_index < 0 or output[i, j, id_index] >= 0):
+                if num_valid_boxes == max_output_size:
+                    for k in range(box_data_length):
+                        output[i, j, k] = -one
+                    box_indices[i, j] = -1
+
+                elif output[i, j, score_index] > 0:
                     box_a_idx = j
-                    for k in parallel(valid_count[i]):
+                    is_valid_box = 1
+
+                    # a_l: left, a_t: top, a_r: right, a_b: bottom
+                    a_l = min(output[batch_idx, box_a_idx, box_start_idx],
+                              output[batch_idx, box_a_idx, box_start_idx + 2])
+                    a_t = min(output[batch_idx, box_a_idx, box_start_idx + 1],
+                              output[batch_idx, box_a_idx, box_start_idx + 3])
+                    a_r = max(output[batch_idx, box_a_idx, box_start_idx],
+                              output[batch_idx, box_a_idx, box_start_idx + 2])
+                    a_b = max(output[batch_idx, box_a_idx, box_start_idx + 1],
+                              output[batch_idx, box_a_idx, box_start_idx + 3])
+
+                    # check if current box j is valid by calculating iou with
+                    # all existing valid boxes
+                    for k in range(j):
                         check_iou = 0
-                        if k > j and output[i, k, score_index] > 0 \
+                        if is_valid_box == 1 and k < j and output[i, k, score_index] > 0 \
                                 and (id_index < 0 or output[i, k, id_index] >= 0):
                             if force_suppress:
                                 check_iou = 1
@@ -340,16 +363,6 @@ def hybrid_nms(data, sorted_index, valid_count, indices, batch_size, num_anchors
                                 check_iou = 1
 
                         if check_iou > 0:
-                            # a_l: left, a_t: top, a_r: right, a_b: bottom
-                            a_l = min(output[batch_idx, box_a_idx, box_start_idx],
-                                      output[batch_idx, box_a_idx, box_start_idx + 2])
-                            a_t = min(output[batch_idx, box_a_idx, box_start_idx + 1],
-                                      output[batch_idx, box_a_idx, box_start_idx + 3])
-                            a_r = max(output[batch_idx, box_a_idx, box_start_idx],
-                                      output[batch_idx, box_a_idx, box_start_idx + 2])
-                            a_b = max(output[batch_idx, box_a_idx, box_start_idx + 1],
-                                      output[batch_idx, box_a_idx, box_start_idx + 3])
-
                             box_b_idx = k
 
                             # b_l: left, b_t: top, b_r: right, b_b: bottom
@@ -377,10 +390,14 @@ def hybrid_nms(data, sorted_index, valid_count, indices, batch_size, num_anchors
                             iou = zero if u <= zero else area / u
 
                             if iou >= iou_threshold:
-                                output[i, k, score_index] = -one
-                                if id_index >= 0:
-                                    output[i, k, id_index] = -one
-                                box_indices[i, k] = -1
+                                is_valid_box = 0
+
+                    if is_valid_box == 0:
+                        for k in range(box_data_length):
+                            output[i, j, k] = -one
+                        box_indices[i, j] = -1
+                    else:
+                        num_valid_boxes += 1
 
         else:
             for j in parallel(valid_count[i]):
@@ -394,18 +411,6 @@ def hybrid_nms(data, sorted_index, valid_count, indices, batch_size, num_anchors
                 output[i, j + valid_count[i], k] = -one
             box_indices[i, j + valid_count[i]] = -1
 
-        # Only return max_output_size valid boxes
-        num_valid_boxes = 0
-        if max_output_size > 0:
-            for j in range(valid_count[i]):
-                if output[i, j, 0] >= zero:
-                    if num_valid_boxes == max_output_size:
-                        for k in range(box_data_length):
-                            output[i, j, k] = -one
-                        box_indices[i, j] = -1
-                    else:
-                        num_valid_boxes += 1
-
         if return_indices:
             for j in range(valid_count[i]):
                 idx = box_indices[i, j]
@@ -432,9 +437,9 @@ def non_max_suppression(data, valid_count, indices, max_output_size=-1,
     indices : tvm.te.Tensor
         2-D tensor with shape [batch_size, num_anchors].
 
-    max_output_size : optional, int
+    max_output_size : optional, int or tvm.te.Tensor
         Max number of output valid boxes for each instance.
-        By default all valid boxes are returned.
+        Return all valid boxes if the value of max_output_size is less than 0.
 
     iou_threshold : optional, float
         Non-maximum suppression threshold.
@@ -494,17 +499,20 @@ def non_max_suppression(data, valid_count, indices, max_output_size=-1,
     """
     batch_size = data.shape[0]
     num_anchors = data.shape[1]
+    if isinstance(max_output_size, int):
+        max_output_size = tvm.tir.const(max_output_size, dtype="int32")
     score_axis = score_index
     score_shape = (batch_size, num_anchors)
     score_tensor = te.compute(score_shape, lambda i, j: data[i, j, score_axis])
     sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False)
+
     out, box_indices = hybrid_nms(data,
                                   sort_tensor,
                                   valid_count,
                                   indices,
                                   batch_size,
                                   num_anchors,
-                                  tvm.tir.const(max_output_size, dtype="int32"),
+                                  max_output_size,
                                   tvm.tir.const(iou_threshold, dtype=data.dtype),
                                   tvm.tir.const(force_suppress, dtype="bool"),
                                   tvm.tir.const(top_k, dtype="int32"),
diff --git a/topi/tests/python/test_topi_vision.py b/topi/tests/python/test_topi_vision.py
index d2331ee..b74e193 100644
--- a/topi/tests/python/test_topi_vision.py
+++ b/topi/tests/python/test_topi_vision.py
@@ -132,7 +132,7 @@ def test_get_valid_counts():
     verify_get_valid_counts((16, 500, 5), 0.95, -1, 1)
 
 
-def verify_non_max_suppression(np_data, np_valid_count, np_indices, np_result, np_indices_result,
+def verify_non_max_suppression(np_data, np_valid_count, np_indices, np_result, np_indices_result, max_output_size,
                                iou_threshold, force_suppress, top_k, coord_start, score_index, id_index):
     dshape = np_data.shape
     batch, num_anchors, _ = dshape
@@ -149,11 +149,11 @@ def verify_non_max_suppression(np_data, np_valid_count, np_indices, np_result, n
         print("Running on target: %s" % device)
         with tvm.target.create(device):
             fcompute, fschedule = topi.testing.dispatch(device, _nms_implement)
-            out = fcompute(data, valid_count, indices, -1, iou_threshold, force_suppress, top_k,
-                           coord_start=coord_start, score_index=score_index, id_index=id_index,
+            out = fcompute(data, valid_count, indices, max_output_size, iou_threshold, force_suppress,
+                           top_k, coord_start=coord_start, score_index=score_index, id_index=id_index,
                            return_indices=False)
-            indices_out = fcompute(data, valid_count, indices, -1, iou_threshold, force_suppress, top_k,
-                                   coord_start=coord_start, score_index=score_index, id_index=id_index,
+            indices_out = fcompute(data, valid_count, indices, max_output_size, iou_threshold, force_suppress,
+                                   top_k, coord_start=coord_start, score_index=score_index, id_index=id_index,
                                    return_indices=True)
             s = fschedule(out)
             indices_s = fschedule(indices_out)
@@ -186,23 +186,27 @@ def test_non_max_suppression():
                          [1, 0.5, 100, 60, 70, 110]]]).astype("float32")
     np_valid_count = np.array([4]).astype("int32")
     np_indices = np.array([[0, 1, 2, 3, 4]]).astype("int32")
+    max_output_size = -1
     np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45],
                            [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1],
                            [-1, -1, -1, -1, -1, -1]]])
     np_indices_result = np.array([[3, 0, -1, -1, -1]])
 
-    verify_non_max_suppression(np_data, np_valid_count, np_indices, np_result, np_indices_result, 0.7, True, 2, 2, 1, 0)
+    verify_non_max_suppression(np_data, np_valid_count, np_indices, np_result, np_indices_result,
+                               max_output_size, 0.7, True, 2, 2, 1, 0)
 
     np_data = np.array([[[0.8, 1, 20, 25, 45], [0.7, 30, 60, 50, 80],
                          [0.4, 4, 21, 19, 40], [0.9, 35, 61, 52, 79],
                          [0.5, 100, 60, 70, 110]]]).astype("float32")
     np_valid_count = np.array([4]).astype("int32")
     np_indices = np.array([[0, 1, 2, 3, 4]]).astype("int32")
+    max_output_size = 2
     np_result = np.array([[[0.9, 35, 61, 52, 79], [0.8, 1, 20, 25, 45],
                            [-1, -1, -1, -1, -1], [-1, -1, -1, -1, -1],
                            [-1, -1, -1, -1, -1]]])
     np_indices_result = np.array([[3, 0, -1, -1, -1]])
-    verify_non_max_suppression(np_data, np_valid_count, np_indices, np_result, np_indices_result, 0.7, False, 2, 1, 0, -1)
+    verify_non_max_suppression(np_data, np_valid_count, np_indices, np_result, np_indices_result,
+                               max_output_size, 0.7, False, 2, 1, 0, -1)
 
 
 def verify_multibox_prior(dshape, sizes=(1,), ratios=(1,), steps=(-1, -1), offsets=(0.5, 0.5), clip=False):


Mime
View raw message