spark-commits mailing list archives

From dav...@apache.org
Subject spark git commit: [SPARK-16175] [PYSPARK] handle None for UDT
Date Tue, 28 Jun 2016 21:10:02 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-2.0 5c9555e11 -> 43bd612f3


[SPARK-16175] [PYSPARK] handle None for UDT

## What changes were proposed in this pull request?

The Scala UDT implementation bypasses null values entirely and never passes them into
serialize() and deserialize(); this PR updates the Python UDT to behave the same way.
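The null-guard pattern applied by the patch can be sketched in a standalone way (this is a hypothetical `PointUDT`, not the actual `pyspark.sql.types` code; in Spark the internal conversion also goes through `self._cachedSqlType()`):

```python
class PointUDT:
    """Hypothetical UDT storing an (x, y) point as a list internally."""

    def serialize(self, obj):
        # After the patch, called only for non-None values.
        return [obj[0], obj[1]]

    def deserialize(self, datum):
        return (datum[0], datum[1])

    def toInternal(self, obj):
        # Guard: skip serialization for None, mirroring the Scala behavior.
        if obj is not None:
            return self.serialize(obj)

    def fromInternal(self, obj):
        # In Spark this value comes from self._cachedSqlType().fromInternal(obj).
        if obj is not None:
            return self.deserialize(obj)


udt = PointUDT()
print(udt.toInternal(None))        # None passes through untouched
print(udt.toInternal((1.0, 2.0)))  # [1.0, 2.0]
print(udt.fromInternal([3.0, 4.0]))  # (3.0, 4.0)
```

Without the guards, `serialize(None)` would raise a `TypeError` on subscripting, which is exactly the failure mode the patch avoids for UDF results that return None.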

## How was this patch tested?

Added tests.

Author: Davies Liu <davies@databricks.com>

Closes #13878 from davies/udt_null.

(cherry picked from commit 35438fb0ad3bcda5c5a3a0ccde1a620699d012db)
Signed-off-by: Davies Liu <davies.liu@gmail.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/43bd612f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/43bd612f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/43bd612f

Branch: refs/heads/branch-2.0
Commit: 43bd612f35490c11a76d5379d723ba65f7afbefd
Parents: 5c9555e
Author: Davies Liu <davies@databricks.com>
Authored: Tue Jun 28 14:09:38 2016 -0700
Committer: Davies Liu <davies.liu@gmail.com>
Committed: Tue Jun 28 14:09:58 2016 -0700

----------------------------------------------------------------------
 python/pyspark/sql/tests.py | 11 +++++++++++
 python/pyspark/sql/types.py |  7 +++++--
 2 files changed, 16 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/43bd612f/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index f863485..a8ca386 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -575,6 +575,17 @@ class SQLTests(ReusedPySparkTestCase):
         _verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT())
         self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT()))
 
+    def test_udt_with_none(self):
+        df = self.spark.range(0, 10, 1, 1)
+
+        def myudf(x):
+            if x > 0:
+                return PythonOnlyPoint(float(x), float(x))
+
+        self.spark.catalog.registerFunction("udf", myudf, PythonOnlyUDT())
+        rows = [r[0] for r in df.selectExpr("udf(id)").take(2)]
+        self.assertEqual(rows, [None, PythonOnlyPoint(1, 1)])
+
     def test_infer_schema_with_udt(self):
         from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
         row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))

http://git-wip-us.apache.org/repos/asf/spark/blob/43bd612f/python/pyspark/sql/types.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index f0b56be..a367987 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -648,10 +648,13 @@ class UserDefinedType(DataType):
         return cls._cached_sql_type
 
     def toInternal(self, obj):
-        return self._cachedSqlType().toInternal(self.serialize(obj))
+        if obj is not None:
+            return self._cachedSqlType().toInternal(self.serialize(obj))
 
     def fromInternal(self, obj):
-        return self.deserialize(self._cachedSqlType().fromInternal(obj))
+        v = self._cachedSqlType().fromInternal(obj)
+        if v is not None:
+            return self.deserialize(v)
 
     def serialize(self, obj):
         """


