mxnet-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From j...@apache.org
Subject [incubator-mxnet] branch master updated: [MXNET-257] Do not copy when casting as same type (#10347)
Date Sun, 01 Apr 2018 23:11:51 GMT
This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 0291f95  [MXNET-257] Do not copy when casting as same type (#10347)
0291f95 is described below

commit 0291f955035d2f912b17831b4eebdeb96b0ff29d
Author: ThomasDelteil <thomas.delteil1@gmail.com>
AuthorDate: Sun Apr 1 16:11:47 2018 -0700

    [MXNET-257] Do not copy when casting as same type (#10347)
    
    * do not copy on same type when casting
    
    * update to have a copy flag
    
    * adding sparse array
---
 python/mxnet/ndarray/ndarray.py              | 15 +++++++++++--
 python/mxnet/ndarray/sparse.py               | 10 ++++++++-
 tests/python/unittest/test_ndarray.py        | 32 ++++++++++++++++++++++++++++
 tests/python/unittest/test_sparse_ndarray.py | 32 ++++++++++++++++++++++++++++
 4 files changed, 86 insertions(+), 3 deletions(-)

diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 7dc2acf..3febf09 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -1843,18 +1843,25 @@ fixed-size items.
             raise ValueError("The current array is not a scalar")
         return self.asnumpy()[0]
 
-    def astype(self, dtype):
+    def astype(self, dtype, copy=True):
         """Returns a copy of the array after casting to a specified type.
 
         Parameters
         ----------
         dtype : numpy.dtype or str
             The type of the returned array.
+        copy : bool
+            Default `True`. By default, astype always returns a newly
+            allocated ndarray on the same context. If this is set to
+            False, and the dtype requirement is satisfied,
+            the input ndarray is returned instead of a copy.
 
         Returns
         -------
         NDArray, CSRNDArray or RowSparseNDArray
-            The copied array after casting to the specified type.
+            The copied array after casting to the specified type, or
+            the same array if copy=False and dtype is the same as the input
+            array.
 
         Examples
         --------
@@ -1863,6 +1870,10 @@ fixed-size items.
         >>> y.dtype
         <type 'numpy.int32'>
         """
+
+        if not copy and np.dtype(dtype) == self.dtype:
+            return self
+
         res = empty(self.shape, ctx=self.context, dtype=dtype)
         self.copyto(res)
         return res
diff --git a/python/mxnet/ndarray/sparse.py b/python/mxnet/ndarray/sparse.py
index da33f58..7ea4023 100644
--- a/python/mxnet/ndarray/sparse.py
+++ b/python/mxnet/ndarray/sparse.py
@@ -179,12 +179,17 @@ class BaseSparseNDArray(NDArray):
         """
         return self.tostype('default').asnumpy()
 
-    def astype(self, dtype):
+    def astype(self, dtype, copy=True):
         """Returns a copy of the array after casting to a specified type.
         Parameters
         ----------
         dtype : numpy.dtype or str
             The type of the returned array.
+        copy : bool
+            Default `True`. By default, astype always returns a newly
+            allocated ndarray on the same context. If this is set to
+            False, and the dtype requirement is satisfied,
+            the input ndarray is returned instead of a copy.
         Examples
         --------
         >>> x = mx.nd.sparse.zeros('row_sparse', (2,3), dtype='float32')
@@ -192,6 +197,9 @@ class BaseSparseNDArray(NDArray):
         >>> y.dtype
         <type 'numpy.int32'>
         """
+        if not copy and np.dtype(dtype) == self.dtype:
+            return self
+
         res = zeros(shape=self.shape, ctx=self.context,
                     dtype=dtype, stype=self.stype)
         self.copyto(res)
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index 16f08b0..c9790f8 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -1134,6 +1134,38 @@ def test_assign_a_row_to_ndarray():
     a_nd[0, :] = a_nd[1]
     assert same(a_np, a_nd.asnumpy())
 
+@with_seed()
+def test_ndarray_astype():
+    x = mx.nd.zeros((2, 3), dtype='int32')
+    y = x.astype('float32')
+    assert (y.dtype==np.float32)
+    # Test that a new ndarray has been allocated
+    assert (id(x) != id(y))
+
+    x = mx.nd.zeros((2, 3), dtype='int32')
+    y = x.astype('float32', copy=False)
+    assert (y.dtype==np.float32)
+    # copy=False still allocates a new ndarray when the dtype differs
+    assert (id(x) != id(y))
+
+    x = mx.nd.zeros((2, 3), dtype='int32')
+    y = x.astype('int32')
+    assert (y.dtype==np.int32)
+    # Test that a new ndarray has been allocated
+    # even though they have same dtype
+    assert (id(x) != id(y))
+
+    # Test that a new ndarray has not been allocated
+    x = mx.nd.zeros((2, 3), dtype='int32')
+    y = x.astype('int32', copy=False)
+    assert (id(x) == id(y))
+    
+    # Test that passing the dtype as np.int32
+    # behaves the same as passing the string 'int32'
+    x = mx.nd.zeros((2, 3), dtype='int32')
+    y = x.astype(np.int32, copy=False)
+    assert (id(x) == id(y))
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py
index 182e70c..169ed89 100644
--- a/tests/python/unittest/test_sparse_ndarray.py
+++ b/tests/python/unittest/test_sparse_ndarray.py
@@ -23,6 +23,7 @@ from common import setup_module, with_seed, random_seed
 from mxnet.base import mx_real_t
 from numpy.testing import assert_allclose
 import numpy.random as rnd
+import numpy as np
 from common import assertRaises
 from mxnet.ndarray.sparse import RowSparseNDArray, CSRNDArray
 
@@ -443,6 +444,37 @@ def test_sparse_nd_astype():
         assert(y.dtype == np.int32), y.dtype
 
 
+@with_seed()
+def test_sparse_nd_astype_copy():
+    stypes = ['row_sparse', 'csr']
+    for stype in stypes:
+        x = mx.nd.zeros(shape=rand_shape_2d(), stype=stype, dtype='int32')
+        y = x.astype('float32')
+        assert (y.dtype == np.float32)
+        # Test that a new ndarray has been allocated
+        assert (id(x) != id(y))
+
+        y = x.astype('float32', copy=False)
+        assert (y.dtype == np.float32)
+        # copy=False still allocates a new ndarray when the dtype differs
+        assert (id(x) != id(y))
+
+        y = x.astype('int32')
+        assert (y.dtype == np.int32)
+        # Test that a new ndarray has been allocated
+        # even though they have same dtype
+        assert (id(x) != id(y))
+
+        # Test that a new ndarray has not been allocated
+        y = x.astype('int32', copy=False)
+        assert (id(x) == id(y))
+
+        # Test that passing the dtype as np.int32
+        # behaves the same as passing the string 'int32'
+        y = x.astype(np.int32, copy=False)
+        assert (id(x) == id(y))
+
+
 @with_seed(0)
 def test_sparse_nd_pickle():
     repeat = 1

-- 
To stop receiving notification emails like this one, please contact
jxie@apache.org.

Mime
View raw message