From commits-return-33881-archive-asf-public=cust-asf.ponee.io@spark.apache.org Thu Oct  4 03:44:02 2018
From: gurwls223@apache.org
To: commits@spark.apache.org
Subject: spark git commit: [SPARK-25601][PYTHON] Register Grouped aggregate UDF Vectorized UDFs for SQL Statement
Date: Thu, 4 Oct 2018 01:44:00 +0000 (UTC)

Repository: spark
Updated Branches:
  refs/heads/branch-2.4 443d12dbb -> 0763b758d


[SPARK-25601][PYTHON] Register Grouped aggregate UDF Vectorized UDFs for SQL Statement

## What changes were proposed in this pull request?

This PR proposes to allow grouped aggregate (vectorized) pandas UDFs to be registered for use in SQL statements, for instance:

```python
from pyspark.sql.functions import pandas_udf, PandasUDFType

@pandas_udf("integer", PandasUDFType.GROUPED_AGG)
def sum_udf(v):
    return v.sum()

spark.udf.register("sum_udf", sum_udf)
q = "SELECT v2, sum_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2"
spark.sql(q).show()
```

```
+---+-----------+
| v2|sum_udf(v1)|
+---+-----------+
|  1|          1|
|  0|          5|
+---+-----------+
```

## How was this patch tested?

Manual tests and a unit test.

Closes #22620 from HyukjinKwon/SPARK-25601.
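As an additional illustration (not part of the original commit message): a minimal sketch, assuming an active SparkSession named `spark` with pandas and pyarrow installed; the names `registered_sum` and `df` are illustrative. It shows that the object returned by `spark.udf.register` can also be used as a column expression in the DataFrame API while the registered name stays callable from SQL:

```python
from pyspark.sql.functions import pandas_udf, PandasUDFType

@pandas_udf("integer", PandasUDFType.GROUPED_AGG)
def sum_udf(v):
    # Receives a pandas Series per group and returns a single scalar per group.
    return v.sum()

# register() returns a UserDefinedFunction, so the same UDF is usable both
# by name in SQL (as shown above) and as a column expression in the DataFrame API.
registered_sum = spark.udf.register("sum_udf", sum_udf)

df = spark.createDataFrame([(3, 0), (2, 0), (1, 1)], ["v1", "v2"])
df.groupby("v2").agg(registered_sum(df.v1)).show()
```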
Authored-by: hyukjinkwon
Signed-off-by: hyukjinkwon

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0763b758
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0763b758
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0763b758

Branch: refs/heads/branch-2.4
Commit: 0763b758de55fd14d7da4832d01b5713e582b257
Parents: 443d12d
Author: hyukjinkwon
Authored: Thu Oct 4 09:36:23 2018 +0800
Committer: hyukjinkwon
Committed: Thu Oct 4 09:43:42 2018 +0800

----------------------------------------------------------------------
 python/pyspark/sql/tests.py | 20 ++++++++++++++++++--
 python/pyspark/sql/udf.py   | 15 +++++++++++++--
 2 files changed, 31 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0763b758/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 690035a..e991032 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -5595,8 +5595,9 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
         foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUPED_MAP)
 
         with QuietTest(self.sc):
-            with self.assertRaisesRegexp(ValueError, 'f must be either SQL_BATCHED_UDF or '
-                                                     'SQL_SCALAR_PANDAS_UDF'):
+            with self.assertRaisesRegexp(
+                    ValueError,
+                    'f.*SQL_BATCHED_UDF.*SQL_SCALAR_PANDAS_UDF.*SQL_GROUPED_AGG_PANDAS_UDF.*'):
                 self.spark.catalog.registerFunction("foo_udf", foo_udf)
 
     def test_decorator(self):
@@ -6412,6 +6413,21 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
                     'mixture.*aggregate function.*group aggregate pandas UDF'):
                 df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect()
 
+    def test_register_vectorized_udf_basic(self):
+        from pyspark.sql.functions import pandas_udf
+        from pyspark.rdd import PythonEvalType
+
+        sum_pandas_udf = pandas_udf(
+            lambda v: v.sum(), "integer", PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)
+
+        self.assertEqual(sum_pandas_udf.evalType, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)
+        group_agg_pandas_udf = self.spark.udf.register("sum_pandas_udf", sum_pandas_udf)
+        self.assertEqual(group_agg_pandas_udf.evalType, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)
+        q = "SELECT sum_pandas_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2"
+        actual = sorted(map(lambda r: r[0], self.spark.sql(q).collect()))
+        expected = [1, 5]
+        self.assertEqual(actual, expected)
+
 
 @unittest.skipIf(
     not _have_pandas or not _have_pyarrow,
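As a hedged illustration of the behaviour the test above covers (not part of the commit; it assumes an active SparkSession named `spark` with pandas and pyarrow installed, and `identity_udf` is an illustrative name): registering a grouped map pandas UDF is still rejected, and the error message now lists SQL_GROUPED_AGG_PANDAS_UDF among the accepted eval types:

```python
from pyspark.sql.functions import pandas_udf, PandasUDFType

# Grouped map pandas UDFs remain non-registrable for SQL statements.
@pandas_udf("id long", PandasUDFType.GROUPED_MAP)
def identity_udf(pdf):
    return pdf

try:
    spark.udf.register("identity_udf", identity_udf)
except ValueError as e:
    # The message mentions SQL_BATCHED_UDF, SQL_SCALAR_PANDAS_UDF and
    # SQL_GROUPED_AGG_PANDAS_UDF as the accepted eval types.
    print(e)
```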
http://git-wip-us.apache.org/repos/asf/spark/blob/0763b758/python/pyspark/sql/udf.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 9dbe49b..58f4e0d 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -298,6 +298,15 @@ class UDFRegistration(object):
         >>> spark.sql("SELECT add_one(id) FROM range(3)").collect()  # doctest: +SKIP
         [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)]
 
+        >>> @pandas_udf("integer", PandasUDFType.GROUPED_AGG)  # doctest: +SKIP
+        ... def sum_udf(v):
+        ...     return v.sum()
+        ...
+        >>> _ = spark.udf.register("sum_udf", sum_udf)  # doctest: +SKIP
+        >>> q = "SELECT sum_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2"
+        >>> spark.sql(q).collect()  # doctest: +SKIP
+        [Row(sum_udf(v1)=1), Row(sum_udf(v1)=5)]
+
         .. note:: Registration for a user-defined function (case 2.)
             was added from Spark 2.3.0.
         """
@@ -310,9 +319,11 @@ class UDFRegistration(object):
                     "Invalid returnType: data type can not be specified when f is"
                     "a user-defined function, but got %s." % returnType)
             if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF,
-                                  PythonEvalType.SQL_SCALAR_PANDAS_UDF]:
+                                  PythonEvalType.SQL_SCALAR_PANDAS_UDF,
+                                  PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF]:
                 raise ValueError(
-                    "Invalid f: f must be either SQL_BATCHED_UDF or SQL_SCALAR_PANDAS_UDF")
+                    "Invalid f: f must be SQL_BATCHED_UDF, SQL_SCALAR_PANDAS_UDF or "
+                    "SQL_GROUPED_AGG_PANDAS_UDF")
             register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name,
                                                evalType=f.evalType,
                                                deterministic=f.deterministic)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org