spark-commits mailing list archives

From ma...@apache.org
Subject git commit: SPARK-1428: MLlib should convert non-float64 NumPy arrays to float64 instead of complaining
Date Thu, 10 Apr 2014 18:18:04 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.0 9ae80bf9b -> 2ac43add1


SPARK-1428: MLlib should convert non-float64 NumPy arrays to float64 instead of complaining

Author: Sandeep <sandeep@techaddict.me>

Closes #356 from techaddict/1428 and squashes the following commits:

3bdf5f6 [Sandeep] SPARK-1428: MLlib should convert non-float64 NumPy arrays to float64 instead of complaining

(cherry picked from commit 3bd312940e2f5250edaf3e88d6c23de25bb1d0a9)
Signed-off-by: Matei Zaharia <matei@databricks.com>
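The subject line describes replacing a dtype check that raised with a cast. The following is a minimal sketch of the before/after behavior. `check_dtype_before` and `check_dtype_after` are hypothetical reductions of the relevant lines in `_serialize_double_vector`, not functions in PySpark.

```python
import numpy as np


def check_dtype_before(v):
    # Hypothetical reduction of the pre-patch behavior: any array whose
    # dtype is not float64 is rejected outright.
    if v.dtype != np.float64:
        raise TypeError("wanted ndarray of float64, got %s" % v.dtype)
    return v


def check_dtype_after(v):
    # Hypothetical reduction of the post-patch behavior: non-float64 arrays
    # are cast with astype, which returns a new float64 copy and leaves the
    # caller's array unchanged.
    if v.dtype != np.float64:
        v = v.astype(np.float64)
    return v


ints = np.array([1, 2, 3])       # integer dtype (platform-dependent width)
print(check_dtype_after(ints))   # [1. 2. 3.]
try:
    check_dtype_before(ints)
except TypeError as e:
    print("rejected:", e)
```

This mirrors the doctest the patch adds: an integer array such as `array([1,2,3])` now round-trips as `[1.0, 2.0, 3.0]` instead of raising `TypeError`.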


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2ac43add
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2ac43add
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2ac43add

Branch: refs/heads/branch-1.0
Commit: 2ac43add15b7002a5bdd75c785a73042b5082506
Parents: 9ae80bf
Author: Sandeep <sandeep@techaddict.me>
Authored: Thu Apr 10 11:17:41 2014 -0700
Committer: Matei Zaharia <matei@databricks.com>
Committed: Thu Apr 10 11:17:55 2014 -0700

----------------------------------------------------------------------
 python/pyspark/mllib/_common.py | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2ac43add/python/pyspark/mllib/_common.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py
index 20a0e30..7ef251d 100644
--- a/python/pyspark/mllib/_common.py
+++ b/python/pyspark/mllib/_common.py
@@ -15,8 +15,9 @@
 # limitations under the License.
 #
 
-from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape
+from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape, complex, issubdtype
 from pyspark import SparkContext, RDD
+import numpy as np
 
 from pyspark.serializers import Serializer
 import struct
@@ -47,13 +48,22 @@ def _deserialize_byte_array(shape, ba, offset):
     return ar.copy()
 
 def _serialize_double_vector(v):
-    """Serialize a double vector into a mutually understood format."""
+    """Serialize a double vector into a mutually understood format.
+
+    >>> x = array([1,2,3])
+    >>> y = _deserialize_double_vector(_serialize_double_vector(x))
+    >>> array_equal(y, array([1.0, 2.0, 3.0]))
+    True
+    """
     if type(v) != ndarray:
         raise TypeError("_serialize_double_vector called on a %s; "
                 "wanted ndarray" % type(v))
+    """complex is only datatype that can't be converted to float64"""
+    if issubdtype(v.dtype, complex):
+        raise TypeError("_serialize_double_vector called on a %s; "
+                "wanted ndarray" % type(v))
     if v.dtype != float64:
-        raise TypeError("_serialize_double_vector called on an ndarray of %s; "
-                "wanted ndarray of float64" % v.dtype)
+        v = v.astype(float64)
     if v.ndim != 1:
         raise TypeError("_serialize_double_vector called on a %ddarray; "
                 "wanted a 1darray" % v.ndim)
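A note for readers running this code today: the `complex` name the patch imports from `numpy` was an alias for Python's builtin `complex`. That alias was deprecated in NumPy 1.20 and removed in NumPy 1.24, so the import line fails on current NumPy. The sketch below reproduces the patched checks against current NumPy under that assumption. `serialize_ready_vector` is a hypothetical name. It only validates and converts the array; it does not produce the byte format that `_serialize_double_vector` writes. It tests against `np.complexfloating` so that every complex width is rejected. Complex input needs this explicit guard because casting complex values to float64 discards the imaginary part: NumPy emits a ComplexWarning but does not raise.

```python
import numpy as np


def serialize_ready_vector(v):
    """Hypothetical helper: apply the checks from the patched
    _serialize_double_vector and return a 1-D float64 array."""
    # Mirror the patch: only plain ndarrays are accepted.
    if type(v) is not np.ndarray:
        raise TypeError("expected ndarray, got %s" % type(v))
    # np.complexfloating covers complex64, complex128 and any wider complex
    # types; none can be cast to float64 without losing the imaginary part.
    if np.issubdtype(v.dtype, np.complexfloating):
        raise TypeError("expected a real-valued ndarray, got dtype %s" % v.dtype)
    # Integer, boolean and float32 arrays, for example, become a float64 copy.
    if v.dtype != np.float64:
        v = v.astype(np.float64)
    if v.ndim != 1:
        raise TypeError("expected a 1-d ndarray, got %d dimensions" % v.ndim)
    return v


print(serialize_ready_vector(np.array([1, 2, 3], dtype=np.int32)))  # [1. 2. 3.]
```

The patch's complex branch reuses the generic "wanted ndarray" message. This sketch instead reports the offending dtype, which makes the failure easier to diagnose.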

