mxnet-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From j...@apache.org
Subject [incubator-mxnet] branch master updated: Add asscipy feature (#8362)
Date Sun, 05 Nov 2017 00:10:13 GMT
This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 63d5394  Add asscipy feature (#8362)
63d5394 is described below

commit 63d5394c41ac1eaadbb8133e1da0c3a559e4e84a
Author: Anirudh Subramanian <anirudh2290@gmail.com>
AuthorDate: Sat Nov 4 17:10:10 2017 -0700

    Add asscipy feature (#8362)
    
    * Add asscipy support and coo format support
    
    * Comment misalignment change
    
    * Add documentation for Sparse NDarray
    
    * Change comment
    
    * Adding comments and support for dtype
    
    * Modifying tests
    
    * Add spsp None check
    
    * Fix lint
---
 docs/api/python/ndarray/sparse.md            |  1 +
 python/mxnet/ndarray/sparse.py               | 66 +++++++++++++++++++++++++---
 tests/python/unittest/test_sparse_ndarray.py | 33 +++++++++++---
 3 files changed, 89 insertions(+), 11 deletions(-)

diff --git a/docs/api/python/ndarray/sparse.md b/docs/api/python/ndarray/sparse.md
index 0424405..9b742f4 100644
--- a/docs/api/python/ndarray/sparse.md
+++ b/docs/api/python/ndarray/sparse.md
@@ -124,6 +124,7 @@ We summarize the interface for each class in the following sections.
     CSRNDArray.copyto
     CSRNDArray.as_in_context
     CSRNDArray.asnumpy
+    CSRNDArray.asscipy
     CSRNDArray.asscalar
     CSRNDArray.astype
     CSRNDArray.tostype
diff --git a/python/mxnet/ndarray/sparse.py b/python/mxnet/ndarray/sparse.py
index a31ec01..45a269a 100644
--- a/python/mxnet/ndarray/sparse.py
+++ b/python/mxnet/ndarray/sparse.py
@@ -492,6 +492,27 @@ class CSRNDArray(BaseSparseNDArray):
         else:
             raise TypeError('copyto does not support type ' + str(type(other)))
 
+    def asscipy(self):
+        """Returns a ``scipy.sparse.csr.csr_matrix`` object with values copied from this array
+
+        Examples
+        --------
+        >>> x = mx.nd.sparse.zeros('csr', (2,3))
+        >>> y = x.asscipy()
+        >>> type(y)
+        <type 'scipy.sparse.csr.csr_matrix'>
+        >>> y
+        <2x3 sparse matrix of type '<type 'numpy.float32'>'
+        with 0 stored elements in Compressed Sparse Row format>
+        """
+        data = self.data.asnumpy()
+        indices = self.indices.asnumpy()
+        indptr = self.indptr.asnumpy()
+        if not spsp:
+            raise ImportError("scipy is not available. \
+                               Please check if the scipy python bindings are installed.")
+        return spsp.csr_matrix((data, indices, indptr), shape=self.shape, dtype=self.dtype)
+
 # pylint: disable=abstract-method
 class RowSparseNDArray(BaseSparseNDArray):
     """A sparse representation of a set of NDArray row slices at given indices.
@@ -798,10 +819,30 @@ def csr_matrix(arg1, shape=None, ctx=None, dtype=None):
             The default dtype is ``data.dtype`` if ``data`` is an NDArray or numpy.ndarray,
\
             float32 otherwise.
 
+    - csr_matrix((data, (row, col)))
+        to construct a CSRNDArray based on the COOrdinate format \
+        using three separate arrays, \
+        where ``row[i]`` is the row index of the element, \
+        ``col[i]`` is the column index of the element \
+        and ``data[i]`` is the data corresponding to the element. All the missing \
+        elements in the input are taken to be zeroes.
+            - **data** (*array_like*) - An object exposing the array interface, which \
+            holds all the non-zero entries of the matrix in COO format.
+            - **row** (*array_like*) - An object exposing the array interface, which \
+            stores the row index for each non-zero element in ``data``.
+            - **col** (*array_like*) - An object exposing the array interface, which \
+            stores the column index for each non-zero element in ``data``.
+            - **shape** (*tuple of int, optional*) - The shape of the array. The default
\
+            shape is inferred from the ``row`` and ``col`` arrays.
+            - **ctx** (*Context, optional*) - Device context \
+            (default is the current default context).
+            - **dtype** (*str or numpy.dtype, optional*) - The data type of the output array.
\
+            The default dtype is float32.
+
     Parameters
     ----------
-    arg1: NDArray, CSRNDArray, numpy.ndarray, scipy.sparse.csr.csr_matrix, tuple of int or
tuple \
-    of array_like
+    arg1: array_like, CSRNDArray, scipy.sparse.csr_matrix,
\
+    scipy.sparse.coo_matrix, tuple of int or tuple of array_like
         The argument to help instantiate the csr matrix. See above for further details.
     shape : tuple of int, optional
         The shape of the csr matrix.
@@ -832,9 +873,24 @@ def csr_matrix(arg1, shape=None, ctx=None, dtype=None):
     if isinstance(arg1, tuple):
         arg_len = len(arg1)
         if arg_len == 2:
-            # empty matrix with shape
-            _check_shape(arg1, shape)
-            return empty('csr', arg1, ctx=ctx, dtype=dtype)
+            # construct a sparse csr matrix from
+            # scipy coo matrix if input format is coo
+            if isinstance(arg1[1], tuple) and len(arg1[1]) == 2:
+                data, (row, col) = arg1
+                if isinstance(data, NDArray):
+                    data = data.asnumpy()
+                if isinstance(row, NDArray):
+                    row = row.asnumpy()
+                if isinstance(col, NDArray):
+                    col = col.asnumpy()
+                coo = spsp.coo_matrix((data, (row, col)), shape=shape)
+                _check_shape(coo.shape, shape)
+                csr = coo.tocsr()
+                return array(csr, ctx=ctx, dtype=dtype)
+            else:
+                # empty matrix with shape
+                _check_shape(arg1, shape)
+                return empty('csr', arg1, ctx=ctx, dtype=dtype)
         elif arg_len == 3:
             # data, indices, indptr
             return _csr_matrix_from_definition(arg1[0], arg1[1], arg1[2], shape=shape,
diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py
index 0bbb5ff..03a4307 100644
--- a/tests/python/unittest/test_sparse_ndarray.py
+++ b/tests/python/unittest/test_sparse_ndarray.py
@@ -476,17 +476,38 @@ def test_create_csr():
         assert same(csr_created.indptr.asnumpy(), indptr.asnumpy())
         assert same(csr_created.indices.asnumpy(), indices.asnumpy())
         # verify csr matrix dtype and ctx is consistent from the ones provided
-        assert csr_created.dtype == dtype
-        assert csr_created.data.dtype == dtype
-        assert csr_created.context == Context.default_ctx
+        assert csr_created.dtype == dtype, (csr_created, dtype)
+        assert csr_created.data.dtype == dtype, (csr_created.data.dtype, dtype)
+        assert csr_created.context == Context.default_ctx, (csr_created.context, Context.default_ctx)
         csr_copy = mx.nd.array(csr_created)
         assert(same(csr_copy.asnumpy(), csr_created.asnumpy()))
 
+    def check_create_csr_from_coo(shape, density, dtype):
+        matrix = rand_ndarray(shape, 'csr', density)
+        sp_csr = matrix.asscipy()
+        sp_coo = sp_csr.tocoo()
+        csr_created = mx.nd.sparse.csr_matrix((sp_coo.data, (sp_coo.row, sp_coo.col)), shape=shape,
dtype=dtype)
+        assert csr_created.stype == 'csr'
+        assert same(csr_created.data.asnumpy(), sp_csr.data)
+        assert same(csr_created.indptr.asnumpy(), sp_csr.indptr)
+        assert same(csr_created.indices.asnumpy(), sp_csr.indices)
+        csr_copy = mx.nd.array(csr_created)
+        assert(same(csr_copy.asnumpy(), csr_created.asnumpy()))
+        # verify csr matrix dtype and ctx is consistent
+        assert csr_created.dtype == dtype, (csr_created.dtype, dtype)
+        assert csr_created.data.dtype == dtype, (csr_created.data.dtype, dtype)
+        assert csr_created.context == Context.default_ctx, (csr_created.context, Context.default_ctx)
+
     def check_create_csr_from_scipy(shape, density, f):
         def assert_csr_almost_equal(nd, sp):
             assert_almost_equal(nd.data.asnumpy(), sp.data)
             assert_almost_equal(nd.indptr.asnumpy(), sp.indptr)
             assert_almost_equal(nd.indices.asnumpy(), sp.indices)
+            sp_csr = nd.asscipy()
+            assert_almost_equal(sp_csr.data, sp.data)
+            assert_almost_equal(sp_csr.indptr, sp.indptr)
+            assert_almost_equal(sp_csr.indices, sp.indices)
+            assert(sp.dtype == sp_csr.dtype), (sp.dtype, sp_csr.dtype)
 
         try:
             import scipy.sparse as spsp
@@ -498,8 +519,8 @@ def test_create_csr():
             indptr = np.array([0, 2, 3, 7])
             indices = np.array([0, 2, 2, 0, 1, 2, 1])
             data = np.array([1, 2, 3, 4, 5, 6, 1])
-            non_canonical_csr = spsp.csr_matrix((data, indices, indptr), shape=(3, 3))
-            canonical_csr_nd = f(non_canonical_csr)
+            non_canonical_csr = spsp.csr_matrix((data, indices, indptr), shape=(3, 3), dtype=csr_nd.dtype)
+            canonical_csr_nd = f(non_canonical_csr, dtype=csr_nd.dtype)
             canonical_csr_sp = non_canonical_csr.copy()
             canonical_csr_sp.sum_duplicates()
             canonical_csr_sp.sort_indices()
@@ -514,10 +535,10 @@ def test_create_csr():
     for density in densities:
         shape = rand_shape_2d(dim0, dim1)
         check_create_csr_from_nd(shape, density, dtype)
+        check_create_csr_from_coo(shape, density, dtype)
         check_create_csr_from_scipy(shape, density, mx.nd.sparse.array)
         check_create_csr_from_scipy(shape, density, mx.nd.array)
 
-
 def test_create_row_sparse():
     dim0 = 50
     dim1 = 50

-- 
To stop receiving notification emails like this one, please contact
"commits@mxnet.apache.org" <commits@mxnet.apache.org>.

Mime
View raw message