spark-commits mailing list archives

From dav...@apache.org
Subject spark git commit: [SPARK-15888] [SQL] fix Python UDF with aggregate
Date Wed, 15 Jun 2016 20:38:09 GMT
Repository: spark
Updated Branches:
  refs/heads/master 279bd4aa5 -> 5389013ac


[SPARK-15888] [SQL] fix Python UDF with aggregate

## What changes were proposed in this pull request?

After we moved the ExtractPythonUDF rule into the physical plan, Python UDFs no longer work on top of
an aggregate, because they cannot be evaluated before the aggregate; they must be evaluated after it.
This PR adds another rule that extracts this kind of Python UDF from the logical Aggregate and
creates a Project on top of the Aggregate.
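
For reference, a minimal PySpark sketch of the query shape this rule handles (it mirrors the regression test added in `python/pyspark/sql/tests.py` below; the session setup here is only illustrative):

```
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, col, sum
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])

my_copy = udf(lambda x: x, IntegerType())
my_add = udf(lambda a, b: int(a + b), IntegerType())
my_strlen = udf(lambda x: len(x), IntegerType())

# my_add consumes the aggregate output (k, s), so it has to be extracted into a
# Project above the Aggregate instead of being evaluated before the aggregation.
sel = df.groupBy(my_copy(col("key")).alias("k")) \
    .agg(sum(my_strlen(col("value"))).alias("s")) \
    .select(my_add(col("k"), col("s")).alias("t"))
```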

## How was this patch tested?

Added a regression test. The plan of the added test query looks like this:
```
== Parsed Logical Plan ==
'Project [<lambda>('k, 's) AS t#26]
+- Aggregate [<lambda>(key#5L)], [<lambda>(key#5L) AS k#17, sum(cast(<lambda>(value#6) as bigint)) AS s#22L]
   +- LogicalRDD [key#5L, value#6]

== Analyzed Logical Plan ==
t: int
Project [<lambda>(k#17, s#22L) AS t#26]
+- Aggregate [<lambda>(key#5L)], [<lambda>(key#5L) AS k#17, sum(cast(<lambda>(value#6) as bigint)) AS s#22L]
   +- LogicalRDD [key#5L, value#6]

== Optimized Logical Plan ==
Project [<lambda>(agg#29, agg#30L) AS t#26]
+- Aggregate [<lambda>(key#5L)], [<lambda>(key#5L) AS agg#29, sum(cast(<lambda>(value#6) as bigint)) AS agg#30L]
   +- LogicalRDD [key#5L, value#6]

== Physical Plan ==
*Project [pythonUDF0#37 AS t#26]
+- BatchEvalPython [<lambda>(agg#29, agg#30L)], [agg#29, agg#30L, pythonUDF0#37]
   +- *HashAggregate(key=[<lambda>(key#5L)#31], functions=[sum(cast(<lambda>(value#6) as bigint))], output=[agg#29,agg#30L])
      +- Exchange hashpartitioning(<lambda>(key#5L)#31, 200)
         +- *HashAggregate(key=[pythonUDF0#34 AS <lambda>(key#5L)#31], functions=[partial_sum(cast(pythonUDF1#35 as bigint))], output=[<lambda>(key#5L)#31,sum#33L])
            +- BatchEvalPython [<lambda>(key#5L), <lambda>(value#6)], [key#5L, value#6, pythonUDF0#34, pythonUDF1#35]
               +- Scan ExistingRDD[key#5L,value#6]
```
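
Note that the physical plan contains two BatchEvalPython nodes: the first evaluates the input UDFs (`<lambda>(key#5L)`, `<lambda>(value#6)`) before the partial aggregation, and the second evaluates the UDF over the aggregate results (`<lambda>(agg#29, agg#30L)`) in the Project that the new rule places on top of the Aggregate.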

Author: Davies Liu <davies@databricks.com>

Closes #13682 from davies/fix_py_udf.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5389013a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5389013a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5389013a

Branch: refs/heads/master
Commit: 5389013acc99367729dfc6deeb2cecc9edd1e24c
Parents: 279bd4a
Author: Davies Liu <davies@databricks.com>
Authored: Wed Jun 15 13:38:04 2016 -0700
Committer: Davies Liu <davies.liu@gmail.com>
Committed: Wed Jun 15 13:38:04 2016 -0700

----------------------------------------------------------------------
 python/pyspark/sql/tests.py                     | 10 ++-
 .../spark/sql/execution/SparkOptimizer.scala    |  6 +-
 .../execution/python/BatchEvalPythonExec.scala  |  2 +
 .../execution/python/ExtractPythonUDFs.scala    | 70 +++++++++++++++++---
 4 files changed, 77 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5389013a/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 1d5d691..c631ad8 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -339,13 +339,21 @@ class SQLTests(ReusedPySparkTestCase):
 
     def test_udf_with_aggregate_function(self):
         df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
-        from pyspark.sql.functions import udf, col
+        from pyspark.sql.functions import udf, col, sum
         from pyspark.sql.types import BooleanType
 
         my_filter = udf(lambda a: a == 1, BooleanType())
         sel = df.select(col("key")).distinct().filter(my_filter(col("key")))
         self.assertEqual(sel.collect(), [Row(key=1)])
 
+        my_copy = udf(lambda x: x, IntegerType())
+        my_add = udf(lambda a, b: int(a + b), IntegerType())
+        my_strlen = udf(lambda x: len(x), IntegerType())
+        sel = df.groupBy(my_copy(col("key")).alias("k"))\
+            .agg(sum(my_strlen(col("value"))).alias("s"))\
+            .select(my_add(col("k"), col("s")).alias("t"))
+        self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)])
+
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.spark.read.json(rdd)

http://git-wip-us.apache.org/repos/asf/spark/blob/5389013a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index 08b2d7f..12a10cb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
 import org.apache.spark.sql.ExperimentalMethods
 import org.apache.spark.sql.catalyst.catalog.SessionCatalog
 import org.apache.spark.sql.catalyst.optimizer.Optimizer
+import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate
 import org.apache.spark.sql.internal.SQLConf
 
 class SparkOptimizer(
@@ -28,6 +29,7 @@ class SparkOptimizer(
     experimentalMethods: ExperimentalMethods)
   extends Optimizer(catalog, conf) {
 
-  override def batches: Seq[Batch] = super.batches :+ Batch(
-    "User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*)
+  override def batches: Seq[Batch] = super.batches :+
+    Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+
+    Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations:
_*)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/5389013a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
index 061d7c7..d9bf4d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
@@ -46,6 +46,8 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
 
   def children: Seq[SparkPlan] = child :: Nil
 
+  override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
+
   private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
     udf.children match {
       case Seq(u: PythonUDF) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/5389013a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index ab19236..668470e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -18,12 +18,68 @@
 package org.apache.spark.sql.execution.python
 
 import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution
 import org.apache.spark.sql.execution.SparkPlan
 
+
+/**
+ * Extracts all the Python UDFs from a logical Aggregate that depend on an aggregate expression or
+ * grouping key, so that they can be evaluated after the aggregate.
+ */
+private[spark] object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
+
+  /**
+   * Returns whether the expression could only be evaluated within aggregate.
+   */
+  private def belongAggregate(e: Expression, agg: Aggregate): Boolean = {
+    e.isInstanceOf[AggregateExpression] ||
+      agg.groupingExpressions.exists(_.semanticEquals(e))
+  }
+
+  private def hasPythonUdfOverAggregate(expr: Expression, agg: Aggregate): Boolean = {
+    expr.find {
+      e => e.isInstanceOf[PythonUDF] && e.find(belongAggregate(_, agg)).isDefined
+    }.isDefined
+  }
+
+  private def extract(agg: Aggregate): LogicalPlan = {
+    val projList = new ArrayBuffer[NamedExpression]()
+    val aggExpr = new ArrayBuffer[NamedExpression]()
+    agg.aggregateExpressions.foreach { expr =>
+      if (hasPythonUdfOverAggregate(expr, agg)) {
+        // Python UDF can only be evaluated after aggregate
+        val newE = expr transformDown {
+          case e: Expression if belongAggregate(e, agg) =>
+            val alias = e match {
+              case a: NamedExpression => a
+              case o => Alias(e, "agg")()
+            }
+            aggExpr += alias
+            alias.toAttribute
+        }
+        projList += newE.asInstanceOf[NamedExpression]
+      } else {
+        aggExpr += expr
+        projList += expr.toAttribute
+      }
+    }
+    // There is no Python UDF over aggregate expression
+    Project(projList, agg.copy(aggregateExpressions = aggExpr))
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+    case agg: Aggregate if agg.aggregateExpressions.exists(hasPythonUdfOverAggregate(_, agg)) =>
+      extract(agg)
+  }
+}
+
+
 /**
  * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated
  * alone in a batch.
@@ -59,10 +115,12 @@ private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] {
   }
 
   /**
-   * Extract all the PythonUDFs from the current operator.
+   * Extract all the PythonUDFs from the current operator and evaluate them before the operator.
    */
-  def extract(plan: SparkPlan): SparkPlan = {
+  private def extract(plan: SparkPlan): SparkPlan = {
     val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
+      // ignore the PythonUDFs that come from the second/third aggregate, which are not used
+      .filter(udf => udf.references.subsetOf(plan.inputSet))
     if (udfs.isEmpty) {
       // If there aren't any, we are done.
       plan
@@ -89,11 +147,7 @@ private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] {
       // Other cases are disallowed as they are ambiguous or would require a cartesian
       // product.
       udfs.filterNot(attributeMap.contains).foreach { udf =>
-        if (udf.references.subsetOf(plan.inputSet)) {
-          sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
-        } else {
-          sys.error(s"Unable to evaluate PythonUDF $udf. Missing input attributes.")
-        }
+        sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
       }
 
       val rewritten = plan.transformExpressions {

