spark-commits mailing list archives

From r...@apache.org
Subject spark git commit: [SPARK-15415][SQL] Fix BroadcastHint when autoBroadcastJoinThreshold is 0 or -1
Date Sun, 22 May 2016 06:01:25 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-2.0 9a08c9f1c -> fd7e83119


[SPARK-15415][SQL] Fix BroadcastHint when autoBroadcastJoinThreshold is 0 or -1

## What changes were proposed in this pull request?

This PR makes BroadcastHint deterministic by introducing a dedicated isBroadcastable property
on Statistics instead of forcing sizeInBytes to 1. With autoBroadcastJoinThreshold set to 0 or -1,
a faked size of 1 still failed the sizeInBytes <= threshold check, so the hint was silently ignored.

See https://issues.apache.org/jira/browse/SPARK-15415
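
As a hedged illustration (not part of this commit; the session setup and data are hypothetical),
this is the scenario the fix targets: with size-based auto-broadcasting disabled, an explicit
broadcast() hint should still produce a broadcast hash join.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.broadcast

val spark = SparkSession.builder().master("local").getOrCreate()
// Disable size-based auto broadcasting entirely.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")

import spark.implicits._
val large = Seq((1, "a"), (2, "b")).toDF("key", "value")
val small = Seq((1, "x")).toDF("key", "value")

// Previously BroadcastHint faked sizeInBytes = 1, but 1 <= -1 (or 1 <= 0) is
// false, so the hint was ignored; isBroadcastable bypasses the size check.
large.join(broadcast(small), "key").explain()
```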

## How was this patch tested?

Added test cases verifying that a broadcast hash join appears in the plan when a BroadcastHint
is supplied, plus tests for propagation of the hint through joins and other operators.
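
For orientation, a minimal sketch of the assertion pattern these tests use (a hypothetical
continuation of the snippet above; the actual suite is in the diff below):

```scala
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec

// Collect broadcast hash join nodes from the physical plan and require exactly one.
val plan = large.join(broadcast(small), "key").queryExecution.sparkPlan
assert(plan.collect { case b: BroadcastHashJoinExec => b }.size == 1)
```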

Author: Jurriaan Pruis <email@jurriaanpruis.nl>

Closes #13244 from jurriaan/broadcast-hint.

(cherry picked from commit 223f6339088434eb3590c2f42091a38f05f1e5db)
Signed-off-by: Reynold Xin <rxin@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fd7e8311
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fd7e8311
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fd7e8311

Branch: refs/heads/branch-2.0
Commit: fd7e83119948187212ce75f4837cd7487d8a128a
Parents: 9a08c9f
Author: Jurriaan Pruis <email@jurriaanpruis.nl>
Authored: Sat May 21 23:01:14 2016 -0700
Committer: Reynold Xin <rxin@databricks.com>
Committed: Sat May 21 23:01:22 2016 -0700

----------------------------------------------------------------------
 .../catalyst/plans/logical/LogicalPlan.scala    |   3 +-
 .../sql/catalyst/plans/logical/Statistics.scala |   2 +-
 .../plans/logical/basicLogicalOperators.scala   |  29 ++++--
 .../spark/sql/execution/SparkStrategies.scala   |   3 +-
 .../execution/joins/BroadcastJoinSuite.scala    | 103 ++++++++++++++++---
 5 files changed, 114 insertions(+), 26 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/fd7e8311/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 45ac126..4984f23 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -313,7 +313,8 @@ abstract class UnaryNode extends LogicalPlan {
       // (product of children).
       sizeInBytes = 1
     }
-    Statistics(sizeInBytes = sizeInBytes)
+
+    child.statistics.copy(sizeInBytes = sizeInBytes)
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/fd7e8311/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
index 9ac4c3a..63f86ad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
@@ -32,4 +32,4 @@ package org.apache.spark.sql.catalyst.plans.logical
  * @param sizeInBytes Physical size in bytes. For leaf operators this defaults to 1, otherwise it
  *                    defaults to the product of children's `sizeInBytes`.
  */
-private[sql] case class Statistics(sizeInBytes: BigInt)
+private[sql] case class Statistics(sizeInBytes: BigInt, isBroadcastable: Boolean = false)

http://git-wip-us.apache.org/repos/asf/spark/blob/fd7e8311/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 732b0d7..bed48b6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -163,7 +163,9 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation
     val leftSize = left.statistics.sizeInBytes
     val rightSize = right.statistics.sizeInBytes
     val sizeInBytes = if (leftSize < rightSize) leftSize else rightSize
-    Statistics(sizeInBytes = sizeInBytes)
+    val isBroadcastable = left.statistics.isBroadcastable || right.statistics.isBroadcastable
+
+    Statistics(sizeInBytes = sizeInBytes, isBroadcastable = isBroadcastable)
   }
 }
 
@@ -183,7 +185,7 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le
       duplicateResolved
 
   override def statistics: Statistics = {
-    Statistics(sizeInBytes = left.statistics.sizeInBytes)
+    left.statistics.copy()
   }
 }
 
@@ -330,6 +332,16 @@ case class Join(
     case UsingJoin(_, _) => false
     case _ => resolvedExceptNatural
   }
+
+  override def statistics: Statistics = joinType match {
+    case LeftAnti | LeftSemi =>
+      // LeftSemi and LeftAnti won't ever be bigger than left
+      left.statistics.copy()
+    case _ =>
+      // make sure we don't propagate isBroadcastable in other joins, because
+      // they could explode the size.
+      super.statistics.copy(isBroadcastable = false)
+  }
 }
 
 /**
@@ -338,9 +350,8 @@ case class Join(
 case class BroadcastHint(child: LogicalPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output
 
-  // We manually set statistics of BroadcastHint to smallest value to make sure
-  // the plan wrapped by BroadcastHint will be considered to broadcast later.
-  override def statistics: Statistics = Statistics(sizeInBytes = 1)
+  // set isBroadcastable to true so the child will be broadcasted
+  override def statistics: Statistics = super.statistics.copy(isBroadcastable = true)
 }
 
 case class InsertIntoTable(
@@ -465,7 +476,7 @@ case class Aggregate(
 
   override def statistics: Statistics = {
     if (groupingExpressions.isEmpty) {
-      Statistics(sizeInBytes = 1)
+      super.statistics.copy(sizeInBytes = 1)
     } else {
       super.statistics
     }
@@ -638,7 +649,7 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN
   override lazy val statistics: Statistics = {
     val limit = limitExpr.eval().asInstanceOf[Int]
     val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum
-    Statistics(sizeInBytes = sizeInBytes)
+    child.statistics.copy(sizeInBytes = sizeInBytes)
   }
 }
 
@@ -653,7 +664,7 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo
   override lazy val statistics: Statistics = {
     val limit = limitExpr.eval().asInstanceOf[Int]
     val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum
-    Statistics(sizeInBytes = sizeInBytes)
+    child.statistics.copy(sizeInBytes = sizeInBytes)
   }
 }
 
@@ -690,7 +701,7 @@ case class Sample(
     if (sizeInBytes == 0) {
       sizeInBytes = 1
     }
-    Statistics(sizeInBytes = sizeInBytes)
+    child.statistics.copy(sizeInBytes = sizeInBytes)
   }
 
   override protected def otherCopyArgs: Seq[AnyRef] = isTableSample :: Nil

http://git-wip-us.apache.org/repos/asf/spark/blob/fd7e8311/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 3343039..664e7f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -92,7 +92,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
      * Matches a plan whose output should be small enough to be used in broadcast join.
      */
     private def canBroadcast(plan: LogicalPlan): Boolean = {
-      plan.statistics.sizeInBytes <= conf.autoBroadcastJoinThreshold
+      plan.statistics.isBroadcastable ||
+        plan.statistics.sizeInBytes <= conf.autoBroadcastJoinThreshold
     }
 
     /**

http://git-wip-us.apache.org/repos/asf/spark/blob/fd7e8311/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 730ec43..e681b88 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -22,9 +22,12 @@ import scala.reflect.ClassTag
 import org.scalatest.BeforeAndAfterAll
 
 import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext}
-import org.apache.spark.sql.{QueryTest, SparkSession}
+import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession}
 import org.apache.spark.sql.execution.exchange.EnsureRequirements
+import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SQLTestUtils
 
 /**
  * Test various broadcast join operators.
@@ -33,7 +36,9 @@ import org.apache.spark.sql.functions._
  * unsafe map in [[org.apache.spark.sql.execution.joins.UnsafeHashedRelation]] is not triggered
  * without serializing the hashed relation, which does not happen in local mode.
  */
-class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll {
+class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
+  import testImplicits._
+
   protected var spark: SparkSession = null
 
   /**
@@ -56,30 +61,100 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll {
   /**
    * Test whether the specified broadcast join updates the peak execution memory accumulator.
    */
-  private def testBroadcastJoin[T: ClassTag](name: String, joinType: String): Unit = {
+  private def testBroadcastJoinPeak[T: ClassTag](name: String, joinType: String): Unit = {
     AccumulatorSuite.verifyPeakExecutionMemorySet(spark.sparkContext, name) {
-      val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
-      val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
-      // Comparison at the end is for broadcast left semi join
-      val joinExpression = df1("key") === df2("key") && df1("value") > df2("value")
-      val df3 = df1.join(broadcast(df2), joinExpression, joinType)
-      val plan =
-        EnsureRequirements(spark.sessionState.conf).apply(df3.queryExecution.sparkPlan)
-      assert(plan.collect { case p: T => p }.size === 1)
+      val plan = testBroadcastJoin[T](joinType)
       plan.executeCollect()
     }
   }
 
+  private def testBroadcastJoin[T: ClassTag](joinType: String,
+                                             forceBroadcast: Boolean = false): SparkPlan = {
+    val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
+    var df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
+
+    // Comparison at the end is for broadcast left semi join
+    val joinExpression = df1("key") === df2("key") && df1("value") > df2("value")
+    val df3 = if (forceBroadcast) {
+      df1.join(broadcast(df2), joinExpression, joinType)
+    } else {
+      df1.join(df2, joinExpression, joinType)
+    }
+    val plan =
+      EnsureRequirements(spark.sessionState.conf).apply(df3.queryExecution.sparkPlan)
+    assert(plan.collect { case p: T => p }.size === 1)
+
+    return plan
+  }
+
   test("unsafe broadcast hash join updates peak execution memory") {
-    testBroadcastJoin[BroadcastHashJoinExec]("unsafe broadcast hash join", "inner")
+    testBroadcastJoinPeak[BroadcastHashJoinExec]("unsafe broadcast hash join", "inner")
   }
 
   test("unsafe broadcast hash outer join updates peak execution memory") {
-    testBroadcastJoin[BroadcastHashJoinExec]("unsafe broadcast hash outer join", "left_outer")
+    testBroadcastJoinPeak[BroadcastHashJoinExec]("unsafe broadcast hash outer join", "left_outer")
   }
 
   test("unsafe broadcast left semi join updates peak execution memory") {
-    testBroadcastJoin[BroadcastHashJoinExec]("unsafe broadcast left semi join", "leftsemi")
+    testBroadcastJoinPeak[BroadcastHashJoinExec]("unsafe broadcast left semi join", "leftsemi")
+  }
+
+  test("broadcast hint isn't bothered by authBroadcastJoinThreshold set to low values") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
+      testBroadcastJoin[BroadcastHashJoinExec]("inner", true)
+    }
+  }
+
+  test("broadcast hint isn't bothered by a disabled authBroadcastJoinThreshold") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      testBroadcastJoin[BroadcastHashJoinExec]("inner", true)
+    }
+  }
+
+  test("broadcast hint isn't propagated after a join") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
+      val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
+      val df3 = df1.join(broadcast(df2), Seq("key"), "inner").drop(df2("key"))
+
+      val df4 = spark.createDataFrame(Seq((1, "5"), (2, "5"))).toDF("key", "value")
+      val df5 = df4.join(df3, Seq("key"), "inner")
+
+      val plan =
+        EnsureRequirements(spark.sessionState.conf).apply(df5.queryExecution.sparkPlan)
+
+      assert(plan.collect { case p: BroadcastHashJoinExec => p }.size === 1)
+      assert(plan.collect { case p: SortMergeJoinExec => p }.size === 1)
+    }
   }
 
+  private def assertBroadcastJoin(df : Dataset[Row]) : Unit = {
+    val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
+    val joined = df1.join(df, Seq("key"), "inner")
+
+    val plan =
+      EnsureRequirements(spark.sessionState.conf).apply(joined.queryExecution.sparkPlan)
+
+    assert(plan.collect { case p: BroadcastHashJoinExec => p }.size === 1)
+  }
+
+  test("broadcast hint is propagated correctly") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"), (3, "2"))).toDF("key", "value")
+      val broadcasted = broadcast(df2)
+      val df3 = spark.createDataFrame(Seq((2, "2"), (3, "3"))).toDF("key", "value")
+
+      val cases = Seq(broadcasted.limit(2),
+                      broadcasted.filter("value < 10"),
+                      broadcasted.sample(true, 0.5),
+                      broadcasted.distinct(),
+                      broadcasted.groupBy("value").agg(min($"key").as("key")),
+                      // except and intersect are semi/anti-joins which won't return more data than
+                      // their left argument, so the broadcast hint should be propagated here
+                      broadcasted.except(df3),
+                      broadcasted.intersect(df3))
+
+      cases.foreach(assertBroadcastJoin)
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org

