From: davies@apache.org
To: commits@spark.apache.org
Message-Id: <8dde9d7786474c258ed06c7917312dad@git.apache.org>
Subject: spark git commit: [SPARK-16175] [PYSPARK] handle None for UDT
Date: Tue, 28 Jun 2016 21:10:02 +0000 (UTC)

Repository: spark
Updated Branches:
  refs/heads/branch-2.0 5c9555e11 -> 43bd612f3


[SPARK-16175] [PYSPARK] handle None for UDT

## What changes were proposed in this pull request?

The Scala UDT bypasses all nulls and never passes them to the UDT's serialize() and deserialize(). This PR updates the Python UDT to do the same: toInternal() and fromInternal() now return None for a null value without calling serialize() or deserialize().
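To illustrate why the guard matters, here is a minimal sketch using a hypothetical Point class and a serialize() written without a None check (illustrative only; these are not PySpark's classes). The simplified pre-patch path forwards None into serialize() and fails, while the simplified patched path returns None without calling it. Both functions omit the _cachedSqlType().toInternal(...) step for brevity.

```python
class Point:
    """Hypothetical user type, standing in for a UDT's Python class."""

    def __init__(self, x, y):
        self.x, self.y = x, y


def serialize(obj):
    # A typical UDT serializer that assumes it always receives a Point.
    return [obj.x, obj.y]


def to_internal_before(obj):
    # Simplified pre-patch behaviour: serialize() is always called, even for None.
    return serialize(obj)


def to_internal_after(obj):
    # Simplified post-patch behaviour: None bypasses serialize() and comes back as None.
    if obj is not None:
        return serialize(obj)


try:
    to_internal_before(None)
except AttributeError as e:
    print("before:", type(e).__name__)  # before: AttributeError

print("after:", to_internal_after(None))              # after: None
print("after:", to_internal_after(Point(1.0, 2.0)))  # after: [1.0, 2.0]
```

With the guard in place, UDT authors can keep serialize() and deserialize() free of None handling, matching the Scala behaviour.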

## How was this patch tested?

Added a unit test, test_udt_with_none, to python/pyspark/sql/tests.py.

Author: Davies Liu

Closes #13878 from davies/udt_null.

(cherry picked from commit 35438fb0ad3bcda5c5a3a0ccde1a620699d012db)
Signed-off-by: Davies Liu


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/43bd612f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/43bd612f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/43bd612f

Branch: refs/heads/branch-2.0
Commit: 43bd612f35490c11a76d5379d723ba65f7afbefd
Parents: 5c9555e
Author: Davies Liu
Authored: Tue Jun 28 14:09:38 2016 -0700
Committer: Davies Liu
Committed: Tue Jun 28 14:09:58 2016 -0700

----------------------------------------------------------------------
 python/pyspark/sql/tests.py | 11 +++++++++++
 python/pyspark/sql/types.py |  7 +++++--
 2 files changed, 16 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/43bd612f/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index f863485..a8ca386 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -575,6 +575,17 @@ class SQLTests(ReusedPySparkTestCase):
         _verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT())
         self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT()))
 
+    def test_udt_with_none(self):
+        df = self.spark.range(0, 10, 1, 1)
+
+        def myudf(x):
+            if x > 0:
+                return PythonOnlyPoint(float(x), float(x))
+
+        self.spark.catalog.registerFunction("udf", myudf, PythonOnlyUDT())
+        rows = [r[0] for r in df.selectExpr("udf(id)").take(2)]
+        self.assertEqual(rows, [None, PythonOnlyPoint(1, 1)])
+
     def test_infer_schema_with_udt(self):
         from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
         row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
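The test relies on the None guards this commit adds to UserDefinedType.toInternal() and fromInternal(): myudf returns None for id 0, and that None must survive the round trip instead of reaching serialize() or deserialize(). The plain-Python sketch below mirrors that round trip without Spark. Point, ListSqlType and PointUDT are hypothetical stand-ins (not PySpark classes), and ListSqlType assumes, as an illustration, that the underlying SQL type maps None to None.

```python
class Point:
    """Hypothetical stand-in for PythonOnlyPoint."""

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __eq__(self, other):
        return isinstance(other, Point) and (self.x, self.y) == (other.x, other.y)

    def __repr__(self):
        return "Point(%r, %r)" % (self.x, self.y)


class ListSqlType:
    """Hypothetical underlying SQL type; assumed to map None to None."""

    def toInternal(self, obj):
        return tuple(obj) if obj is not None else None

    def fromInternal(self, obj):
        return list(obj) if obj is not None else None


class PointUDT:
    """Simplified mirror of the patched UserDefinedType methods."""

    _sql_type = ListSqlType()

    def _cachedSqlType(self):
        return self._sql_type

    def toInternal(self, obj):
        # Patched behaviour: None never reaches serialize().
        if obj is not None:
            return self._cachedSqlType().toInternal(self.serialize(obj))

    def fromInternal(self, obj):
        # Patched behaviour: convert first, then skip deserialize() for None.
        v = self._cachedSqlType().fromInternal(obj)
        if v is not None:
            return self.deserialize(v)

    def serialize(self, obj):
        return [obj.x, obj.y]

    def deserialize(self, datum):
        return Point(datum[0], datum[1])


def myudf(x):
    # Same shape as the UDF in test_udt_with_none: returns None for x <= 0.
    if x > 0:
        return Point(float(x), float(x))


udt = PointUDT()
# Round-trip the first two ids, like take(2) on a single-partition range(0, 10).
rows = [udt.fromInternal(udt.toInternal(myudf(i))) for i in range(2)]
print(rows)  # [None, Point(1.0, 1.0)]
```

As in the real test, comparing against Point(1, 1) also succeeds, because 1 == 1.0 in Python.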

http://git-wip-us.apache.org/repos/asf/spark/blob/43bd612f/python/pyspark/sql/types.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index f0b56be..a367987 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -648,10 +648,13 @@ class UserDefinedType(DataType):
         return cls._cached_sql_type
 
     def toInternal(self, obj):
-        return self._cachedSqlType().toInternal(self.serialize(obj))
+        if obj is not None:
+            return self._cachedSqlType().toInternal(self.serialize(obj))
 
     def fromInternal(self, obj):
-        return self.deserialize(self._cachedSqlType().fromInternal(obj))
+        v = self._cachedSqlType().fromInternal(obj)
+        if v is not None:
+            return self.deserialize(v)
 
     def serialize(self, obj):
         """


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org