spark-commits mailing list archives

From yh...@apache.org
Subject spark git commit: [SPARK-17972][SQL] Add Dataset.checkpoint() to truncate large query plans
Date Mon, 31 Oct 2016 20:40:03 GMT
Repository: spark
Updated Branches:
  refs/heads/master 26b07f190 -> 8bfc3b7aa


[SPARK-17972][SQL] Add Dataset.checkpoint() to truncate large query plans

## What changes were proposed in this pull request?
### Problem

Iterative ML code may easily create query plans that grow exponentially. We found that query
planning time also increases exponentially even when all the sub-plan trees are cached.

The following snippet illustrates the problem:

``` scala
(0 until 6).foldLeft(Seq(1, 2, 3).toDS) { (plan, iteration) =>
  println(s"== Iteration $iteration ==")
  val time0 = System.currentTimeMillis()
  val joined = plan.join(plan, "value").join(plan, "value").join(plan, "value").join(plan, "value")
  joined.cache()
  println(s"Query planning takes ${System.currentTimeMillis() - time0} ms")
  joined.as[Int]
}

// == Iteration 0 ==
// Query planning takes 9 ms
// == Iteration 1 ==
// Query planning takes 26 ms
// == Iteration 2 ==
// Query planning takes 53 ms
// == Iteration 3 ==
// Query planning takes 163 ms
// == Iteration 4 ==
// Query planning takes 700 ms
// == Iteration 5 ==
// Query planning takes 3418 ms
```

This is because when building a new Dataset, the new plan is always built upon `QueryExecution.analyzed`,
which doesn't leverage existing cached plans.
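
A rough way to see the growth is to count leaf relations in a toy model. The sketch below
is not Spark code: `Plan`, `Leaf`, `Join`, and `PlanGrowth` are hypothetical names made up
for illustration, and it assumes every self-join embeds a full copy of the previous plan,
which is what building on the analyzed plan amounts to.

``` scala
// Toy stand-ins for a logical plan; these are NOT Spark classes.
sealed trait Plan
case object Leaf extends Plan
final case class Join(left: Plan, right: Plan) extends Plan

object PlanGrowth {
  // Counts leaf relations, visiting every copy of a shared subtree separately.
  def leaves(p: Plan): BigInt = p match {
    case Leaf       => 1
    case Join(l, r) => leaves(l) + leaves(r)
  }

  // One iteration of the snippet above: the plan joined with itself four more times.
  def step(p: Plan): Plan = Seq.fill(4)(p).foldLeft(p)(Join(_, _))

  // Leaf counts after 1, 2, ..., n iterations.
  def growth(n: Int): Seq[BigInt] =
    Iterator.iterate(step(Leaf))(step).take(n).map(leaves).toSeq
}
```

In this model the leaf count is multiplied by 5 on every iteration, so `PlanGrowth.growth(3)`
yields 5, 25, and 125. The measured planning times above do not follow that factor exactly,
since the model ignores everything except plan size.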

On the other hand, caching every few iterations is usually not the right answer to this
problem, since caching consumes too much memory (imagine computing connected components
over a graph with 50 billion nodes). What we really need here is to truncate both the query
plan (to minimize query planning time) and the lineage of the underlying RDD (to avoid stack
overflow).
### Changes introduced in this PR

This PR tries to fix this issue by introducing a `checkpoint()` method into `Dataset[T]`,
which truncates both the query plan and the RDD lineage as described above. The snippet in
the micro benchmark section below, which is essentially the same as the one above but invokes
`checkpoint()` instead of `cache()`, shows the result of this PR.

One key point is that the checkpointed Dataset should preserve the partitioning and
ordering information of the original Dataset, so that we can avoid unnecessary shuffling (similar
to reading from a pre-bucketed table). This is done by adding `outputPartitioning` and `outputOrdering`
to `LogicalRDD` and `RDDScanExec`.
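
The diff below keeps only the first leaf of a nested `PartitioningCollection` when recording
the partitioning. A minimal sketch of that idea, using simplified stand-in types rather than
Spark's own classes (the case classes and the `FirstLeaf` object here are hypothetical and
exist only for illustration):

``` scala
// Simplified stand-ins for Spark's partitioning classes; NOT the real definitions.
sealed trait Partitioning
final case class HashPartitioning(columns: Seq[String], numPartitions: Int) extends Partitioning
final case class PartitioningCollection(partitionings: Seq[Partitioning]) extends Partitioning

object FirstLeaf {
  // Follows the first element of each nested collection until a non-collection is found.
  // Assumes every collection is non-empty; `head` throws on an empty one.
  @annotation.tailrec
  def firstLeafPartitioning(p: Partitioning): Partitioning = p match {
    case PartitioningCollection(ps) => firstLeafPartitioning(ps.head)
    case leaf                       => leaf
  }

  // Example input: a collection nested inside another collection.
  val left: Partitioning   = HashPartitioning(Seq("a.value"), 200)
  val right: Partitioning  = HashPartitioning(Seq("b.value"), 200)
  val nested: Partitioning =
    PartitioningCollection(Seq(PartitioningCollection(Seq(left, right)), right))
}
```

Here `FirstLeaf.firstLeafPartitioning(FirstLeaf.nested)` returns `left`. The trade-off is that
the other partitionings in the collection are discarded, in exchange for a partitioning
description that stays the same size however deep the joins go.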
### Micro benchmark

``` scala
spark.sparkContext.setCheckpointDir("/tmp/cp")

(0 until 100).foldLeft(Seq(1, 2, 3).toDS) { (plan, iteration) =>
  println(s"== Iteration $iteration ==")
  val time0 = System.currentTimeMillis()
  val cp = plan.checkpoint()
  cp.count()
  System.out.println(s"Checkpointing takes ${System.currentTimeMillis() - time0} ms")

  val time1 = System.currentTimeMillis()
  val joined = cp.join(cp, "value").join(cp, "value").join(cp, "value").join(cp, "value")
  val result = joined.as[Int]

  println(s"Query planning takes ${System.currentTimeMillis() - time1} ms")
  result
}

// == Iteration 0 ==
// Checkpointing takes 591 ms
// Query planning takes 13 ms
// == Iteration 1 ==
// Checkpointing takes 1605 ms
// Query planning takes 16 ms
// == Iteration 2 ==
// Checkpointing takes 782 ms
// Query planning takes 8 ms
// == Iteration 3 ==
// Checkpointing takes 729 ms
// Query planning takes 10 ms
// == Iteration 4 ==
// Checkpointing takes 734 ms
// Query planning takes 9 ms
// == Iteration 5 ==
// ...
// == Iteration 50 ==
// Checkpointing takes 571 ms
// Query planning takes 7 ms
// == Iteration 51 ==
// Checkpointing takes 548 ms
// Query planning takes 7 ms
// == Iteration 52 ==
// Checkpointing takes 596 ms
// Query planning takes 8 ms
// == Iteration 53 ==
// Checkpointing takes 568 ms
// Query planning takes 7 ms
// ...
```

You may see that although checkpointing is a heavier operation than caching, the time taken
by both checkpointing and query planning stays roughly constant across iterations.
### Open question

mengxr mentioned that it would be more convenient if we could make `Dataset.checkpoint()` eager,
i.e., always perform an `RDD.count()` after calling `RDD.checkpoint()`. Not quite sure whether
this is a universal requirement. Maybe we can add an `eager: Boolean` argument to `Dataset.checkpoint()`
to support that.
## How was this patch tested?

Unit test added in `DatasetSuite`.

Author: Cheng Lian <lian@databricks.com>
Author: Yin Huai <yhuai@databricks.com>

Closes #15651 from liancheng/ds-checkpoint.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8bfc3b7a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8bfc3b7a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8bfc3b7a

Branch: refs/heads/master
Commit: 8bfc3b7aac577e36aadc4fe6dee0665d0b2ae919
Parents: 26b07f1
Author: Cheng Lian <lian@databricks.com>
Authored: Mon Oct 31 13:39:59 2016 -0700
Committer: Yin Huai <yhuai@databricks.com>
Committed: Mon Oct 31 13:39:59 2016 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/sql/Dataset.scala    | 57 +++++++++++++++-
 .../spark/sql/execution/ExistingRDD.scala       | 37 +++++++++--
 .../spark/sql/execution/SparkStrategies.scala   |  7 +-
 .../org/apache/spark/sql/DatasetSuite.scala     | 68 ++++++++++++++++++++
 4 files changed, 157 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8bfc3b7a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 286d854..6e0a247 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -40,13 +40,14 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.optimizer.CombineUnions
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection}
 import org.apache.spark.sql.catalyst.util.usePrettyExpression
 import org.apache.spark.sql.execution.{FileRelation, LogicalRDD, QueryExecution, SQLExecution}
 import org.apache.spark.sql.execution.command.{CreateViewCommand, ExplainCommand, GlobalTempView, LocalTempView}
-import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
+import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.datasources.json.JacksonGenerator
 import org.apache.spark.sql.execution.python.EvaluatePython
-import org.apache.spark.sql.streaming.{DataStreamWriter, StreamingQuery}
+import org.apache.spark.sql.streaming.DataStreamWriter
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
@@ -483,6 +484,58 @@ class Dataset[T] private[sql](
   def isStreaming: Boolean = logicalPlan.isStreaming
 
   /**
+   * Returns a checkpointed version of this Dataset.
+   *
+   * @group basic
+   * @since 2.1.0
+   */
+  @Experimental
+  @InterfaceStability.Evolving
+  def checkpoint(): Dataset[T] = checkpoint(eager = true)
+
+  /**
+   * Returns a checkpointed version of this Dataset.
+   *
+   * @param eager When true, materializes the underlying checkpointed RDD eagerly.
+   *
+   * @group basic
+   * @since 2.1.0
+   */
+  @Experimental
+  @InterfaceStability.Evolving
+  def checkpoint(eager: Boolean): Dataset[T] = {
+    val internalRdd = queryExecution.toRdd.map(_.copy())
+    internalRdd.checkpoint()
+
+    if (eager) {
+      internalRdd.count()
+    }
+
+    val physicalPlan = queryExecution.executedPlan
+
+    // Takes the first leaf partitioning whenever we see a `PartitioningCollection`. Otherwise the
+    // size of `PartitioningCollection` may grow exponentially for queries involving deep inner
+    // joins.
+    def firstLeafPartitioning(partitioning: Partitioning): Partitioning = {
+      partitioning match {
+        case p: PartitioningCollection => firstLeafPartitioning(p.partitionings.head)
+        case p => p
+      }
+    }
+
+    val outputPartitioning = firstLeafPartitioning(physicalPlan.outputPartitioning)
+
+    Dataset.ofRows(
+      sparkSession,
+      LogicalRDD(
+        logicalPlan.output,
+        internalRdd,
+        outputPartitioning,
+        physicalPlan.outputOrdering
+      )(sparkSession)).as[T]
+  }
+
+  /**
    * Displays the Dataset in a tabular form. Strings more than 20 characters will be truncated,
    * and all cells will be aligned right. For example:
    * {{{

http://git-wip-us.apache.org/repos/asf/spark/blob/8bfc3b7a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index d3a2222..455fb5b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.types.DataType
 import org.apache.spark.util.Utils
@@ -130,17 +130,40 @@ case class ExternalRDDScanExec[T](
 /** Logical plan node for scanning data from an RDD of InternalRow. */
 case class LogicalRDD(
     output: Seq[Attribute],
-    rdd: RDD[InternalRow])(session: SparkSession)
+    rdd: RDD[InternalRow],
+    outputPartitioning: Partitioning = UnknownPartitioning(0),
+    outputOrdering: Seq[SortOrder] = Nil)(session: SparkSession)
   extends LeafNode with MultiInstanceRelation {
 
   override protected final def otherCopyArgs: Seq[AnyRef] = session :: Nil
 
-  override def newInstance(): LogicalRDD.this.type =
-    LogicalRDD(output.map(_.newInstance()), rdd)(session).asInstanceOf[this.type]
+  override def newInstance(): LogicalRDD.this.type = {
+    val rewrite = output.zip(output.map(_.newInstance())).toMap
+
+    val rewrittenPartitioning = outputPartitioning match {
+      case p: Expression =>
+        p.transform {
+          case e: Attribute => rewrite.getOrElse(e, e)
+        }.asInstanceOf[Partitioning]
+
+      case p => p
+    }
+
+    val rewrittenOrdering = outputOrdering.map(_.transform {
+      case e: Attribute => rewrite.getOrElse(e, e)
+    }.asInstanceOf[SortOrder])
+
+    LogicalRDD(
+      output.map(rewrite),
+      rdd,
+      rewrittenPartitioning,
+      rewrittenOrdering
+    )(session).asInstanceOf[this.type]
+  }
 
   override def sameResult(plan: LogicalPlan): Boolean = {
     plan.canonicalized match {
-      case LogicalRDD(_, otherRDD) => rdd.id == otherRDD.id
+      case LogicalRDD(_, otherRDD, _, _) => rdd.id == otherRDD.id
       case _ => false
     }
   }
@@ -158,7 +181,9 @@ case class LogicalRDD(
 case class RDDScanExec(
     output: Seq[Attribute],
     rdd: RDD[InternalRow],
-    override val nodeName: String) extends LeafExecNode {
+    override val nodeName: String,
+    override val outputPartitioning: Partitioning = UnknownPartitioning(0),
+    override val outputOrdering: Seq[SortOrder] = Nil) extends LeafExecNode {
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))

http://git-wip-us.apache.org/repos/asf/spark/blob/8bfc3b7a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 7cfae5c..5412aca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -32,8 +32,6 @@ import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.exchange.ShuffleExchange
 import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight}
 import org.apache.spark.sql.execution.streaming.{MemoryPlan, StreamingExecutionRelation, StreamingRelation, StreamingRelationExec}
-import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.streaming.StreamingQuery
 
 /**
  * Converts a logical plan into zero or more SparkPlans.  This API is exposed for experimenting
@@ -402,13 +400,14 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           generator, join = join, outer = outer, g.output, planLater(child)) :: Nil
       case logical.OneRowRelation =>
         execution.RDDScanExec(Nil, singleRowRdd, "OneRowRelation") :: Nil
-      case r : logical.Range =>
+      case r: logical.Range =>
         execution.RangeExec(r) :: Nil
       case logical.RepartitionByExpression(expressions, child, nPartitions) =>
         exchange.ShuffleExchange(HashPartitioning(
           expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil
       case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil
-      case LogicalRDD(output, rdd) => RDDScanExec(output, rdd, "ExistingRDD") :: Nil
+      case r: LogicalRDD =>
+        RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil
       case BroadcastHint(child) => planLater(child) :: Nil
       case _ => Nil
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/8bfc3b7a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index cc367ac..55f0487 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -22,8 +22,11 @@ import java.sql.{Date, Timestamp}
 
 import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
 import org.apache.spark.sql.catalyst.util.sideBySide
+import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SortExec}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchange}
 import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
 
@@ -919,6 +922,71 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       df.withColumn("b", expr("0")).as[ClassData]
         .groupByKey(_.a).flatMapGroups { case (x, iter) => List[Int]() })
   }
+
+  Seq(true, false).foreach { eager =>
+    def testCheckpointing(testName: String)(f: => Unit): Unit = {
+      test(s"Dataset.checkpoint() - $testName (eager = $eager)") {
+        withTempDir { dir =>
+          val originalCheckpointDir = spark.sparkContext.checkpointDir
+
+          try {
+            spark.sparkContext.setCheckpointDir(dir.getCanonicalPath)
+            f
+          } finally {
+            // Since the original checkpointDir can be None, we need
+            // to set the variable directly.
+            spark.sparkContext.checkpointDir = originalCheckpointDir
+          }
+        }
+      }
+    }
+
+    testCheckpointing("basic") {
+      val ds = spark.range(10).repartition('id % 2).filter('id > 5).orderBy('id.desc)
+      val cp = ds.checkpoint(eager)
+
+      val logicalRDD = cp.logicalPlan match {
+        case plan: LogicalRDD => plan
+        case _ =>
+          val treeString = cp.logicalPlan.treeString(verbose = true)
+          fail(s"Expecting a LogicalRDD, but got\n$treeString")
+      }
+
+      val dsPhysicalPlan = ds.queryExecution.executedPlan
+      val cpPhysicalPlan = cp.queryExecution.executedPlan
+
+      assertResult(dsPhysicalPlan.outputPartitioning) { logicalRDD.outputPartitioning }
+      assertResult(dsPhysicalPlan.outputOrdering) { logicalRDD.outputOrdering }
+
+      assertResult(dsPhysicalPlan.outputPartitioning) { cpPhysicalPlan.outputPartitioning }
+      assertResult(dsPhysicalPlan.outputOrdering) { cpPhysicalPlan.outputOrdering }
+
+      // For a lazy checkpoint() call, the first check also materializes the checkpoint.
+      checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*)
+
+      // Reads back from checkpointed data and check again.
+      checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*)
+    }
+
+    testCheckpointing("should preserve partitioning information") {
+      val ds = spark.range(10).repartition('id % 2)
+      val cp = ds.checkpoint(eager)
+
+      val agg = cp.groupBy('id % 2).agg(count('id))
+
+      agg.queryExecution.executedPlan.collectFirst {
+        case ShuffleExchange(_, _: RDDScanExec, _) =>
+        case BroadcastExchangeExec(_, _: RDDScanExec) =>
+      }.foreach { _ =>
+        fail(
+          "No Exchange should be inserted above RDDScanExec since the checkpointed Dataset " +
+            "preserves partitioning information:\n\n" + agg.queryExecution
+        )
+      }
+
+      checkAnswer(agg, ds.groupBy('id % 2).agg(count('id)))
+    }
+  }
 }
 
 case class Generic[T](id: T, value: Double)

