sxjscience opened a new issue #17179: [Numpy] where does not support python scalar as the input
URL: https://github.com/apache/incubator-mxnet/issues/17179
```python
import mxnet as mx
mx.npx.set_np()
# Create a symbolic variable and convert it to the NumPy-compatible symbol type
a = mx.sym.var('a').as_np_ndarray()
mx.sym.np.where(a, a, 0)
```
Error message:
```
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
<ipython-input-5-e7a6fe89b743> in <module>
----> 1 mx.sym.np.where(a, a, 0)
~/mxnet/python/mxnet/symbol/numpy/_symbol.py in where(condition, x, y)
5501
5502 """
-> 5503 return _npi.where(condition, x, y, out=None)
5504
5505
~/mxnet/python/mxnet/symbol/register.py in where(condition, x, y, name, attr, out, **kwargs)
AssertionError: Argument y must be Symbol instances, but got 0
```
The imperative API fails as well:
```python
import mxnet as mx
mx.npx.set_np()
mx.np.where(mx.np.ones((10, )), mx.np.ones((10, )), 0)
```
Error message:
```
AssertionError Traceback (most recent call last)
<ipython-input-14-bdddf3065582> in <module>
----> 1 mx.np.where(mx.np.ones((10, )), mx.np.ones((10, )), 0)
~/mxnet/python/mxnet/numpy/multiarray.py in where(condition, x, y)
7996 [ 0., 3., -1.]])
7997 """
-> 7998 return _mx_nd_np.where(condition, x, y)
7999
8000
~/mxnet/python/mxnet/ndarray/numpy/_op.py in where(condition, x, y)
6035 return nonzero(condition)
6036 else:
-> 6037 return _npi.where(condition, x, y, out=None)
6038
6039
~/mxnet/python/mxnet/ndarray/register.py in where(condition, x, y, out, name, **kwargs)
AssertionError: Argument y must have NDArray type, but got 0
```
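For comparison, NumPy's own `where` broadcasts a Python scalar passed as `x` or `y`. The sketch below shows one possible workaround: promote the scalar to an array before calling `where`. The helper name `where_with_scalars` and its `xp` parameter are hypothetical and not part of MXNet. It is exercised here only against reference NumPy. Passing `mx.np` as `xp` is an untested assumption that relies on `mx.np.full` and `mx.np.where` behaving like their NumPy counterparts.

```python
import numbers

import numpy as onp  # reference NumPy, used to show the expected semantics


def where_with_scalars(xp, condition, x, y):
    """Call xp.where after promoting Python scalars in x or y to arrays.

    `xp` is any NumPy-like module. Hypothetical helper, not an MXNet API.
    """
    # Use the first array-valued operand as the shape/dtype template;
    # if both x and y are scalars, fall back to the condition's shape.
    arrays = [v for v in (x, y) if not isinstance(v, numbers.Number)]
    template = arrays[0] if arrays else condition
    dtype = template.dtype if arrays else None

    def promote(value):
        if isinstance(value, numbers.Number):
            return xp.full(template.shape, value, dtype=dtype)
        return value

    return xp.where(condition, promote(x), promote(y))


cond = onp.array([True, False, True])
x = onp.array([1.0, 2.0, 3.0])

# Same result as onp.where(cond, x, 0), which NumPy accepts directly.
print(where_with_scalars(onp, cond, x, 0))
```

This only sidesteps the assertion in the imperative case. The symbolic case would need a different approach, since the helper reads `shape` and `dtype` from its inputs.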