From: davies@apache.org
To: commits@spark.apache.org
Subject: spark git commit: [SPARK-15888] [SQL] fix Python UDF with aggregate
Date: Wed, 15 Jun 2016 20:38:09 +0000 (UTC)

Repository: spark
Updated Branches:
  refs/heads/master 279bd4aa5 -> 5389013ac


[SPARK-15888] [SQL] fix Python UDF with aggregate

## What changes were proposed in this pull request?

After the ExtractPythonUDF rule was moved into the physical plan, a Python UDF could no longer be used on top of an aggregate, because such a UDF cannot be evaluated before the aggregate; it has to be evaluated after it. This PR adds another rule that extracts this kind of Python UDF from the logical Aggregate and creates a Project on top of the Aggregate to evaluate it there.

## How was this patch tested?

Added regression tests. The plan of the added test query looks like this:

```
== Parsed Logical Plan ==
'Project [('k, 's) AS t#26]
+- Aggregate [(key#5L)], [(key#5L) AS k#17, sum(cast((value#6) as bigint)) AS s#22L]
   +- LogicalRDD [key#5L, value#6]

== Analyzed Logical Plan ==
t: int
Project [(k#17, s#22L) AS t#26]
+- Aggregate [(key#5L)], [(key#5L) AS k#17, sum(cast((value#6) as bigint)) AS s#22L]
   +- LogicalRDD [key#5L, value#6]

== Optimized Logical Plan ==
Project [(agg#29, agg#30L) AS t#26]
+- Aggregate [(key#5L)], [(key#5L) AS agg#29, sum(cast((value#6) as bigint)) AS agg#30L]
   +- LogicalRDD [key#5L, value#6]

== Physical Plan ==
*Project [pythonUDF0#37 AS t#26]
+- BatchEvalPython [(agg#29, agg#30L)], [agg#29, agg#30L, pythonUDF0#37]
   +- *HashAggregate(key=[(key#5L)#31], functions=[sum(cast((value#6) as bigint))], output=[agg#29,agg#30L])
      +- Exchange hashpartitioning((key#5L)#31, 200)
         +- *HashAggregate(key=[pythonUDF0#34 AS (key#5L)#31], functions=[partial_sum(cast(pythonUDF1#35 as bigint))], output=[(key#5L)#31,sum#33L])
            +- BatchEvalPython [(key#5L), (value#6)], [key#5L, value#6, pythonUDF0#34, pythonUDF1#35]
               +- Scan ExistingRDD[key#5L,value#6]
```

Author: Davies Liu

Closes #13682 from davies/fix_py_udf.
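For context, here is a minimal PySpark sketch of the query shape this fix enables. It mirrors the regression test added in this commit, is illustrative only, and assumes an active SparkSession bound to `spark`:

```
from pyspark.sql.functions import col, sum, udf
from pyspark.sql.types import IntegerType

# Python UDFs used as the grouping key, inside the aggregate, and on top of
# the aggregated columns; the last case is the one SPARK-15888 fixes.
my_copy = udf(lambda x: x, IntegerType())
my_strlen = udf(lambda x: len(x), IntegerType())
my_add = udf(lambda a, b: int(a + b), IntegerType())

df = spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])

sel = (df.groupBy(my_copy(col("key")).alias("k"))
       .agg(sum(my_strlen(col("value"))).alias("s"))
       .select(my_add(col("k"), col("s")).alias("t")))

print(sel.collect())  # [Row(t=4), Row(t=3)], as asserted in the regression test
```
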
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5389013a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5389013a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5389013a

Branch: refs/heads/master
Commit: 5389013acc99367729dfc6deeb2cecc9edd1e24c
Parents: 279bd4a
Author: Davies Liu
Authored: Wed Jun 15 13:38:04 2016 -0700
Committer: Davies Liu
Committed: Wed Jun 15 13:38:04 2016 -0700

----------------------------------------------------------------------
 python/pyspark/sql/tests.py                     | 10 ++-
 .../spark/sql/execution/SparkOptimizer.scala    |  6 +-
 .../execution/python/BatchEvalPythonExec.scala  |  2 +
 .../execution/python/ExtractPythonUDFs.scala    | 70 +++++++++++++++++---
 4 files changed, 77 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5389013a/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 1d5d691..c631ad8 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -339,13 +339,21 @@ class SQLTests(ReusedPySparkTestCase):
 
     def test_udf_with_aggregate_function(self):
         df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
-        from pyspark.sql.functions import udf, col
+        from pyspark.sql.functions import udf, col, sum
         from pyspark.sql.types import BooleanType
 
         my_filter = udf(lambda a: a == 1, BooleanType())
         sel = df.select(col("key")).distinct().filter(my_filter(col("key")))
         self.assertEqual(sel.collect(), [Row(key=1)])
 
+        my_copy = udf(lambda x: x, IntegerType())
+        my_add = udf(lambda a, b: int(a + b), IntegerType())
+        my_strlen = udf(lambda x: len(x), IntegerType())
+        sel = df.groupBy(my_copy(col("key")).alias("k"))\
+            .agg(sum(my_strlen(col("value"))).alias("s"))\
+            .select(my_add(col("k"), col("s")).alias("t"))
+        self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)])
+
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.spark.read.json(rdd)

http://git-wip-us.apache.org/repos/asf/spark/blob/5389013a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index 08b2d7f..12a10cb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
 import org.apache.spark.sql.ExperimentalMethods
 import org.apache.spark.sql.catalyst.catalog.SessionCatalog
 import org.apache.spark.sql.catalyst.optimizer.Optimizer
+import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate
 import org.apache.spark.sql.internal.SQLConf
 
 class SparkOptimizer(
@@ -28,6 +29,7 @@ class SparkOptimizer(
     experimentalMethods: ExperimentalMethods)
   extends Optimizer(catalog, conf) {
 
-  override def batches: Seq[Batch] = super.batches :+ Batch(
-    "User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*)
+  override def batches: Seq[Batch] = super.batches :+
+    Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+
+    Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/5389013a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
index 061d7c7..d9bf4d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
@@ -46,6 +46,8 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
 
   def children: Seq[SparkPlan] = child :: Nil
 
+  override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
+
   private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
     udf.children match {
       case Seq(u: PythonUDF) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/5389013a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index ab19236..668470e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -18,12 +18,68 @@ package org.apache.spark.sql.execution.python
 
 import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution
 import org.apache.spark.sql.execution.SparkPlan
 
+
+/**
+ * Extracts all the Python UDFs in logical aggregate, which depends on aggregate expression or
+ * grouping key, evaluate them after aggregate.
+ */
+private[spark] object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
+
+  /**
+   * Returns whether the expression could only be evaluated within aggregate.
+   */
+  private def belongAggregate(e: Expression, agg: Aggregate): Boolean = {
+    e.isInstanceOf[AggregateExpression] ||
+      agg.groupingExpressions.exists(_.semanticEquals(e))
+  }
+
+  private def hasPythonUdfOverAggregate(expr: Expression, agg: Aggregate): Boolean = {
+    expr.find {
+      e => e.isInstanceOf[PythonUDF] && e.find(belongAggregate(_, agg)).isDefined
+    }.isDefined
+  }
+
+  private def extract(agg: Aggregate): LogicalPlan = {
+    val projList = new ArrayBuffer[NamedExpression]()
+    val aggExpr = new ArrayBuffer[NamedExpression]()
+    agg.aggregateExpressions.foreach { expr =>
+      if (hasPythonUdfOverAggregate(expr, agg)) {
+        // Python UDF can only be evaluated after aggregate
+        val newE = expr transformDown {
+          case e: Expression if belongAggregate(e, agg) =>
+            val alias = e match {
+              case a: NamedExpression => a
+              case o => Alias(e, "agg")()
+            }
+            aggExpr += alias
+            alias.toAttribute
+        }
+        projList += newE.asInstanceOf[NamedExpression]
+      } else {
+        aggExpr += expr
+        projList += expr.toAttribute
+      }
+    }
+    // There is no Python UDF over aggregate expression
+    Project(projList, agg.copy(aggregateExpressions = aggExpr))
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+    case agg: Aggregate if agg.aggregateExpressions.exists(hasPythonUdfOverAggregate(_, agg)) =>
+      extract(agg)
+  }
+}
+
+
 /**
  * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated
  * alone in a batch.
@@ -59,10 +115,12 @@ private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] {
   }
 
   /**
-   * Extract all the PythonUDFs from the current operator.
+   * Extract all the PythonUDFs from the current operator and evaluate them before the operator.
    */
-  def extract(plan: SparkPlan): SparkPlan = {
+  private def extract(plan: SparkPlan): SparkPlan = {
     val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
+      // ignore the PythonUDF that come from second/third aggregate, which is not used
+      .filter(udf => udf.references.subsetOf(plan.inputSet))
     if (udfs.isEmpty) {
       // If there aren't any, we are done.
       plan
@@ -89,11 +147,7 @@ private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] {
       // Other cases are disallowed as they are ambiguous or would require a cartesian
       // product.
       udfs.filterNot(attributeMap.contains).foreach { udf =>
-        if (udf.references.subsetOf(plan.inputSet)) {
-          sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
-        } else {
-          sys.error(s"Unable to evaluate PythonUDF $udf. Missing input attributes.")
-        }
+        sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
       }
 
       val rewritten = plan.transformExpressions {
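For completeness, one hedged way to observe the effect of the new ExtractPythonUDFFromAggregate rule, continuing the illustrative sketch above and assuming a Spark build that includes this commit:

```
# Print the parsed, analyzed, optimized, and physical plans, as quoted in the
# commit message. After this change the physical plan places a BatchEvalPython
# node above the final HashAggregate, feeding the top-level Project.
sel.explain(True)

# Sanity check: the query now runs end to end instead of failing when the
# Python UDF over the aggregated columns cannot be evaluated.
assert sorted(r.t for r in sel.collect()) == [3, 4]
```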