spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From wenc...@apache.org
Subject spark git commit: [SPARK-22223][SQL] ObjectHashAggregate should not introduce unnecessary shuffle
Date Mon, 16 Oct 2017 05:48:25 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-2.2 6b6761e8f -> 0f060a251


[SPARK-22223][SQL] ObjectHashAggregate should not introduce unnecessary shuffle

`ObjectHashAggregateExec` should override `outputPartitioning` in order to avoid unnecessary
shuffle.

Added Jenkins test.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #19501 from viirya/SPARK-22223.

(cherry picked from commit 0ae96495dedb54b3b6bae0bd55560820c5ca29a2)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0f060a25
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0f060a25
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0f060a25

Branch: refs/heads/branch-2.2
Commit: 0f060a251fb17ccc94bc41f9ea9af2fa39539ff9
Parents: 6b6761e
Author: Liang-Chi Hsieh <viirya@gmail.com>
Authored: Mon Oct 16 13:37:58 2017 +0800
Committer: Wenchen Fan <wenchen@databricks.com>
Committed: Mon Oct 16 13:47:59 2017 +0800

----------------------------------------------------------------------
 .../aggregate/ObjectHashAggregateExec.scala     |  2 ++
 .../spark/sql/DataFrameAggregateSuite.scala     | 31 ++++++++++++++++++++
 2 files changed, 33 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0f060a25/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
index b53521b..b69500d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
@@ -94,6 +94,8 @@ case class ObjectHashAggregateExec(
     }
   }
 
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
   protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
     val numOutputRows = longMetric("numOutputRows")
     val fallbackCountThreshold = sqlContext.conf.objectAggSortBasedFallbackThreshold

http://git-wip-us.apache.org/repos/asf/spark/blob/0f060a25/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index f50c0cf..87aabf7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql
 
+import org.apache.spark.sql.execution.aggregate.{ObjectHashAggregateExec, SortAggregateExec}
+import org.apache.spark.sql.execution.exchange.ShuffleExchange
 import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -577,4 +579,33 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
       spark.sql("SELECT 3 AS c, 4 AS d, SUM(b) FROM testData2 GROUP BY c, d"),
       Seq(Row(3, 4, 9)))
   }
+
+  test("SPARK-22223: ObjectHashAggregate should not introduce unnecessary shuffle") {
+    withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") {
+      val df = Seq(("1", "2", 1), ("1", "2", 2), ("2", "3", 3), ("2", "3", 4)).toDF("a", "b", "c")
+        .repartition(col("a"))
+
+      val objHashAggDF = df
+        .withColumn("d", expr("(a, b, c)"))
+        .groupBy("a", "b").agg(collect_list("d").as("e"))
+        .withColumn("f", expr("(b, e)"))
+        .groupBy("a").agg(collect_list("f").as("g"))
+      val aggPlan = objHashAggDF.queryExecution.executedPlan
+
+      val sortAggPlans = aggPlan.collect {
+        case sortAgg: SortAggregateExec => sortAgg
+      }
+      assert(sortAggPlans.isEmpty)
+
+      val objHashAggPlans = aggPlan.collect {
+        case objHashAgg: ObjectHashAggregateExec => objHashAgg
+      }
+      assert(objHashAggPlans.nonEmpty)
+
+      val exchangePlans = aggPlan.collect {
+        case shuffle: ShuffleExchange => shuffle
+      }
+      assert(exchangePlans.length == 1)
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message