This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 0291f95 [MXNET-257] Do not copy when casting as same type (#10347)
0291f95 is described below
commit 0291f955035d2f912b17831b4eebdeb96b0ff29d
Author: ThomasDelteil <thomas.delteil1@gmail.com>
AuthorDate: Sun Apr 1 16:11:47 2018 -0700
[MXNET-257] Do not copy when casting as same type (#10347)
* do not copy on same type when casting
* update to have a copy flag
* adding sparse array
---
python/mxnet/ndarray/ndarray.py | 15 +++++++++++--
python/mxnet/ndarray/sparse.py | 10 ++++++++-
tests/python/unittest/test_ndarray.py | 32 ++++++++++++++++++++++++++++
tests/python/unittest/test_sparse_ndarray.py | 32 ++++++++++++++++++++++++++++
4 files changed, 86 insertions(+), 3 deletions(-)
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 7dc2acf..3febf09 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -1843,18 +1843,25 @@ fixed-size items.
raise ValueError("The current array is not a scalar")
return self.asnumpy()[0]
- def astype(self, dtype):
+ def astype(self, dtype, copy=True):
"""Returns a copy of the array after casting to a specified type.
Parameters
----------
dtype : numpy.dtype or str
The type of the returned array.
+ copy : bool
+ Default `True`. By default, astype always returns a newly
+ allocated ndarray on the same context. If this is set to
+ False, and the dtype requirement is satisfied,
+ the input ndarray is returned instead of a copy.
Returns
-------
NDArray, CSRNDArray or RowSparseNDArray
- The copied array after casting to the specified type.
+ The copied array after casting to the specified type, or
+ the same array if copy=False and dtype is the same as the input
+ array.
Examples
--------
@@ -1863,6 +1870,10 @@ fixed-size items.
>>> y.dtype
<type 'numpy.int32'>
"""
+
+ if not copy and np.dtype(dtype) == self.dtype:
+ return self
+
res = empty(self.shape, ctx=self.context, dtype=dtype)
self.copyto(res)
return res
diff --git a/python/mxnet/ndarray/sparse.py b/python/mxnet/ndarray/sparse.py
index da33f58..7ea4023 100644
--- a/python/mxnet/ndarray/sparse.py
+++ b/python/mxnet/ndarray/sparse.py
@@ -179,12 +179,17 @@ class BaseSparseNDArray(NDArray):
"""
return self.tostype('default').asnumpy()
- def astype(self, dtype):
+ def astype(self, dtype, copy=True):
"""Returns a copy of the array after casting to a specified type.
Parameters
----------
dtype : numpy.dtype or str
The type of the returned array.
+ copy : bool
+ Default `True`. By default, astype always returns a newly
+ allocated ndarray on the same context. If this is set to
+ False, and the dtype requirement is satisfied,
+ the input ndarray is returned instead of a copy.
Examples
--------
>>> x = mx.nd.sparse.zeros('row_sparse', (2,3), dtype='float32')
@@ -192,6 +197,9 @@ class BaseSparseNDArray(NDArray):
>>> y.dtype
<type 'numpy.int32'>
"""
+ if not copy and np.dtype(dtype) == self.dtype:
+ return self
+
res = zeros(shape=self.shape, ctx=self.context,
dtype=dtype, stype=self.stype)
self.copyto(res)
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index 16f08b0..c9790f8 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -1134,6 +1134,38 @@ def test_assign_a_row_to_ndarray():
a_nd[0, :] = a_nd[1]
assert same(a_np, a_nd.asnumpy())
+@with_seed()
+def test_ndarray_astype():
+ x = mx.nd.zeros((2, 3), dtype='int32')
+ y = x.astype('float32')
+ assert (y.dtype==np.float32)
+ # Test that a new ndarray has been allocated
+ assert (id(x) != id(y))
+
+ x = mx.nd.zeros((2, 3), dtype='int32')
+ y = x.astype('float32', copy=False)
+ assert (y.dtype==np.float32)
+ # Test that a new ndarray has been allocated
+ assert (id(x) != id(y))
+
+ x = mx.nd.zeros((2, 3), dtype='int32')
+ y = x.astype('int32')
+ assert (y.dtype==np.int32)
+ # Test that a new ndarray has been allocated
+ # even though they have same dtype
+ assert (id(x) != id(y))
+
+ # Test that a new ndarray has not been allocated
+ x = mx.nd.zeros((2, 3), dtype='int32')
+ y = x.astype('int32', copy=False)
+ assert (id(x) == id(y))
+
+ # Test the string version 'int32'
+ # has the same behaviour as the np.int32
+ x = mx.nd.zeros((2, 3), dtype='int32')
+ y = x.astype(np.int32, copy=False)
+ assert (id(x) == id(y))
+
if __name__ == '__main__':
import nose
nose.runmodule()
diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py
index 182e70c..169ed89 100644
--- a/tests/python/unittest/test_sparse_ndarray.py
+++ b/tests/python/unittest/test_sparse_ndarray.py
@@ -23,6 +23,7 @@ from common import setup_module, with_seed, random_seed
from mxnet.base import mx_real_t
from numpy.testing import assert_allclose
import numpy.random as rnd
+import numpy as np
from common import assertRaises
from mxnet.ndarray.sparse import RowSparseNDArray, CSRNDArray
@@ -443,6 +444,37 @@ def test_sparse_nd_astype():
assert(y.dtype == np.int32), y.dtype
+@with_seed()
+def test_sparse_nd_astype_copy():
+ stypes = ['row_sparse', 'csr']
+ for stype in stypes:
+ x = mx.nd.zeros(shape=rand_shape_2d(), stype=stype, dtype='int32')
+ y = x.astype('float32')
+ assert (y.dtype == np.float32)
+ # Test that a new ndarray has been allocated
+ assert (id(x) != id(y))
+
+ y = x.astype('float32', copy=False)
+ assert (y.dtype == np.float32)
+ # Test that a new ndarray has been allocated
+ assert (id(x) != id(y))
+
+ y = x.astype('int32')
+ assert (y.dtype == np.int32)
+ # Test that a new ndarray has been allocated
+ # even though they have same dtype
+ assert (id(x) != id(y))
+
+ # Test that a new ndarray has not been allocated
+ y = x.astype('int32', copy=False)
+ assert (id(x) == id(y))
+
+ # Test the string version 'int32'
+ # has the same behaviour as the np.int32
+ y = x.astype(np.int32, copy=False)
+ assert (id(x) == id(y))
+
+
@with_seed(0)
def test_sparse_nd_pickle():
repeat = 1
--
To stop receiving notification emails like this one, please contact
jxie@apache.org.
|