spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From wenc...@apache.org
Subject spark git commit: [SPARK-23375][SQL] Eliminate unneeded Sort in Optimizer
Date Fri, 13 Apr 2018 17:01:08 GMT
Repository: spark
Updated Branches:
  refs/heads/master 4dfd746de -> 25892f3cc


[SPARK-23375][SQL] Eliminate unneeded Sort in Optimizer

## What changes were proposed in this pull request?

Added a new optimizer rule, `RemoveRedundantSorts`, which removes a Sort operation when its child's output ordering already satisfies the requested ordering.
For instance, this simple code:
```
spark.sparkContext.parallelize(Seq(("a", "b"))).toDF("a", "b").registerTempTable("table1")
val df = sql(s"""SELECT b
                | FROM (
                |     SELECT a, b
                |     FROM table1
                |     ORDER BY a
                | ) t
                | ORDER BY a""".stripMargin)
df.explain(true)
```
Before this PR, it produces the following plans:
```
== Parsed Logical Plan ==
'Sort ['a ASC NULLS FIRST], true
+- 'Project ['b]
   +- 'SubqueryAlias t
      +- 'Sort ['a ASC NULLS FIRST], true
         +- 'Project ['a, 'b]
            +- 'UnresolvedRelation `table1`

== Analyzed Logical Plan ==
b: string
Project [b#7]
+- Sort [a#6 ASC NULLS FIRST], true
   +- Project [b#7, a#6]
      +- SubqueryAlias t
         +- Sort [a#6 ASC NULLS FIRST], true
            +- Project [a#6, b#7]
               +- SubqueryAlias table1
                  +- Project [_1#3 AS a#6, _2#4 AS b#7]
                     +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String,
StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]))._1, true,
false) AS _1#3, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString,
assertnotnull(assertnotnull(input[0, scala.Tuple2, true]))._2, true, false) AS _2#4]
                        +- ExternalRDD [obj#2]

== Optimized Logical Plan ==
Project [b#7]
+- Sort [a#6 ASC NULLS FIRST], true
   +- Project [b#7, a#6]
      +- Sort [a#6 ASC NULLS FIRST], true
         +- Project [_1#3 AS a#6, _2#4 AS b#7]
            +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String,
StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._1, true, false) AS _1#3,
staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0,
scala.Tuple2, true])._2, true, false) AS _2#4]
               +- ExternalRDD [obj#2]

== Physical Plan ==
*(3) Project [b#7]
+- *(3) Sort [a#6 ASC NULLS FIRST], true, 0
   +- Exchange rangepartitioning(a#6 ASC NULLS FIRST, 200)
      +- *(2) Project [b#7, a#6]
         +- *(2) Sort [a#6 ASC NULLS FIRST], true, 0
            +- Exchange rangepartitioning(a#6 ASC NULLS FIRST, 200)
               +- *(1) Project [_1#3 AS a#6, _2#4 AS b#7]
                  +- *(1) SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String,
StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._1, true, false) AS _1#3,
staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0,
scala.Tuple2, true])._2, true, false) AS _2#4]
                     +- Scan ExternalRDDScan[obj#2]
```

After this PR, it produces:

```
== Parsed Logical Plan ==
'Sort ['a ASC NULLS FIRST], true
+- 'Project ['b]
   +- 'SubqueryAlias t
      +- 'Sort ['a ASC NULLS FIRST], true
         +- 'Project ['a, 'b]
            +- 'UnresolvedRelation `table1`

== Analyzed Logical Plan ==
b: string
Project [b#7]
+- Sort [a#6 ASC NULLS FIRST], true
   +- Project [b#7, a#6]
      +- SubqueryAlias t
         +- Sort [a#6 ASC NULLS FIRST], true
            +- Project [a#6, b#7]
               +- SubqueryAlias table1
                  +- Project [_1#3 AS a#6, _2#4 AS b#7]
                     +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String,
StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]))._1, true,
false) AS _1#3, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString,
assertnotnull(assertnotnull(input[0, scala.Tuple2, true]))._2, true, false) AS _2#4]
                        +- ExternalRDD [obj#2]

== Optimized Logical Plan ==
Project [b#7]
+- Sort [a#6 ASC NULLS FIRST], true
   +- Project [_1#3 AS a#6, _2#4 AS b#7]
      +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String,
StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._1, true, false) AS _1#3,
staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0,
scala.Tuple2, true])._2, true, false) AS _2#4]
         +- ExternalRDD [obj#2]

== Physical Plan ==
*(2) Project [b#7]
+- *(2) Sort [a#6 ASC NULLS FIRST], true, 0
   +- Exchange rangepartitioning(a#6 ASC NULLS FIRST, 5)
      +- *(1) Project [_1#3 AS a#6, _2#4 AS b#7]
         +- *(1) SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String,
StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._1, true, false) AS _1#3,
staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0,
scala.Tuple2, true])._2, true, false) AS _2#4]
            +- Scan ExternalRDDScan[obj#2]
```

This means that, after this PR, the redundant inner sort (and the shuffle exchange it required in the physical plan) is no longer performed.

## How was this patch tested?

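Added unit tests: a new `RemoveRedundantSortsSuite` for the optimizer rule, and a `PlannerSuite` test checking that cached sorted data is not re-sorted.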

Author: Marco Gaido <marcogaido91@gmail.com>

Closes #20560 from mgaido91/SPARK-23375.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/25892f3c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/25892f3c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/25892f3c

Branch: refs/heads/master
Commit: 25892f3cc9dcb938220be8020a5b9a17c92dbdbe
Parents: 4dfd746
Author: Marco Gaido <marcogaido91@gmail.com>
Authored: Sat Apr 14 01:01:00 2018 +0800
Committer: Wenchen Fan <wenchen@databricks.com>
Committed: Sat Apr 14 01:01:00 2018 +0800

----------------------------------------------------------------------
 .../sql/catalyst/optimizer/Optimizer.scala      |  12 +++
 .../catalyst/plans/logical/LogicalPlan.scala    |   9 ++
 .../plans/logical/basicLogicalOperators.scala   |  23 +++--
 .../optimizer/RemoveRedundantSortsSuite.scala   | 101 +++++++++++++++++++
 .../spark/sql/execution/CacheManager.scala      |   4 +-
 .../spark/sql/execution/ExistingRDD.scala       |   2 +-
 .../execution/columnar/InMemoryRelation.scala   |  17 ++--
 .../apache/spark/sql/ConfigBehaviorSuite.scala  |   2 +-
 .../spark/sql/execution/PlannerSuite.scala      |  15 ++-
 .../columnar/InMemoryColumnarQuerySuite.scala   |  14 +--
 10 files changed, 175 insertions(+), 24 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/25892f3c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 9a1bbc6..5fb59ef 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -138,6 +138,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
     operatorOptimizationBatch) :+
     Batch("Join Reorder", Once,
       CostBasedJoinReorder) :+
+    Batch("Remove Redundant Sorts", Once,
+      RemoveRedundantSorts) :+
     Batch("Decimal Optimizations", fixedPoint,
       DecimalAggregates) :+
     Batch("Object Expressions Optimization", fixedPoint,
@@ -734,6 +736,16 @@ object EliminateSorts extends Rule[LogicalPlan] {
 }
 
 /**
+ * Removes Sort operation if the child is already sorted
+ */
+object RemoveRedundantSorts extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case Sort(orders, true, child) if SortOrder.orderingSatisfies(child.outputOrdering, orders)
=>
+      child
+  }
+}
+
+/**
  * Removes filters that can be evaluated trivially.  This can be done through the following
ways:
  * 1) by eliding the filter for cases where it will always evaluate to `true`.
  * 2) by substituting a dummy empty relation when the filter will always evaluate to `false`.

http://git-wip-us.apache.org/repos/asf/spark/blob/25892f3c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index c8ccd9b..4203440 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -219,6 +219,11 @@ abstract class LogicalPlan
    * Refreshes (or invalidates) any metadata/data cached in the plan recursively.
    */
   def refresh(): Unit = children.foreach(_.refresh())
+
+  /**
+   * Returns the output ordering that this plan generates.
+   */
+  def outputOrdering: Seq[SortOrder] = Nil
 }
 
 /**
@@ -274,3 +279,7 @@ abstract class BinaryNode extends LogicalPlan {
 
   override final def children: Seq[LogicalPlan] = Seq(left, right)
 }
+
+abstract class OrderPreservingUnaryNode extends UnaryNode {
+  override final def outputOrdering: Seq[SortOrder] = child.outputOrdering
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/25892f3c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index a4fca79..10df504 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -43,11 +43,12 @@ case class ReturnAnswer(child: LogicalPlan) extends UnaryNode {
  * This node is inserted at the top of a subquery when it is optimized. This makes sure we
can
  * recognize a subquery as such, and it allows us to write subquery aware transformations.
  */
-case class Subquery(child: LogicalPlan) extends UnaryNode {
+case class Subquery(child: LogicalPlan) extends OrderPreservingUnaryNode {
   override def output: Seq[Attribute] = child.output
 }
 
-case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode
{
+case class Project(projectList: Seq[NamedExpression], child: LogicalPlan)
+    extends OrderPreservingUnaryNode {
   override def output: Seq[Attribute] = projectList.map(_.toAttribute)
   override def maxRows: Option[Long] = child.maxRows
 
@@ -125,7 +126,7 @@ case class Generate(
 }
 
 case class Filter(condition: Expression, child: LogicalPlan)
-  extends UnaryNode with PredicateHelper {
+  extends OrderPreservingUnaryNode with PredicateHelper {
   override def output: Seq[Attribute] = child.output
 
   override def maxRows: Option[Long] = child.maxRows
@@ -469,6 +470,7 @@ case class Sort(
     child: LogicalPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output
   override def maxRows: Option[Long] = child.maxRows
+  override def outputOrdering: Seq[SortOrder] = order
 }
 
 /** Factory for constructing new `Range` nodes. */
@@ -522,6 +524,15 @@ case class Range(
   override def computeStats(): Statistics = {
     Statistics(sizeInBytes = LongType.defaultSize * numElements)
   }
+
+  override def outputOrdering: Seq[SortOrder] = {
+    val order = if (step > 0) {
+      Ascending
+    } else {
+      Descending
+    }
+    output.map(a => SortOrder(a, order))
+  }
 }
 
 case class Aggregate(
@@ -728,7 +739,7 @@ object Limit {
  *
  * See [[Limit]] for more information.
  */
-case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
+case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends OrderPreservingUnaryNode
{
   override def output: Seq[Attribute] = child.output
   override def maxRows: Option[Long] = {
     limitExpr match {
@@ -744,7 +755,7 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends
UnaryN
  *
  * See [[Limit]] for more information.
  */
-case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
+case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends OrderPreservingUnaryNode
{
   override def output: Seq[Attribute] = child.output
 
   override def maxRowsPerPartition: Option[Long] = {
@@ -764,7 +775,7 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends
UnaryNo
 case class SubqueryAlias(
     alias: String,
     child: LogicalPlan)
-  extends UnaryNode {
+  extends OrderPreservingUnaryNode {
 
   override def doCanonicalize(): LogicalPlan = child.canonicalized
 

http://git-wip-us.apache.org/repos/asf/spark/blob/25892f3c/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala
new file mode 100644
index 0000000..2319ab8
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
+import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, ORDER_BY_ORDINAL}
+
+class RemoveRedundantSortsSuite extends PlanTest {
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("Remove Redundant Sorts", Once,
+        RemoveRedundantSorts) ::
+      Batch("Collapse Project", Once,
+        CollapseProject) :: Nil
+  }
+
+  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+
+  test("remove redundant order by") {
+    val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst)
+    val unnecessaryReordered = orderedPlan.select('a).orderBy('a.asc, 'b.desc_nullsFirst)
+    val optimized = Optimize.execute(unnecessaryReordered.analyze)
+    val correctAnswer = orderedPlan.select('a).analyze
+    comparePlans(Optimize.execute(optimized), correctAnswer)
+  }
+
+  test("do not remove sort if the order is different") {
+    val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst)
+    val reorderedDifferently = orderedPlan.select('a).orderBy('a.asc, 'b.desc)
+    val optimized = Optimize.execute(reorderedDifferently.analyze)
+    val correctAnswer = reorderedDifferently.analyze
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("filters don't affect order") {
+    val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc)
+    val filteredAndReordered = orderedPlan.where('a > Literal(10)).orderBy('a.asc, 'b.desc)
+    val optimized = Optimize.execute(filteredAndReordered.analyze)
+    val correctAnswer = orderedPlan.where('a > Literal(10)).analyze
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("limits don't affect order") {
+    val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc)
+    val filteredAndReordered = orderedPlan.limit(Literal(10)).orderBy('a.asc, 'b.desc)
+    val optimized = Optimize.execute(filteredAndReordered.analyze)
+    val correctAnswer = orderedPlan.limit(Literal(10)).analyze
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("range is already sorted") {
+    val inputPlan = Range(1L, 1000L, 1, 10)
+    val orderedPlan = inputPlan.orderBy('id.asc)
+    val optimized = Optimize.execute(orderedPlan.analyze)
+    val correctAnswer = inputPlan.analyze
+    comparePlans(optimized, correctAnswer)
+
+    val reversedPlan = inputPlan.orderBy('id.desc)
+    val reversedOptimized = Optimize.execute(reversedPlan.analyze)
+    val reversedCorrectAnswer = reversedPlan.analyze
+    comparePlans(reversedOptimized, reversedCorrectAnswer)
+
+    val negativeStepInputPlan = Range(10L, 1L, -1, 10)
+    val negativeStepOrderedPlan = negativeStepInputPlan.orderBy('id.desc)
+    val negativeStepOptimized = Optimize.execute(negativeStepOrderedPlan.analyze)
+    val negativeStepCorrectAnswer = negativeStepInputPlan.analyze
+    comparePlans(negativeStepOptimized, negativeStepCorrectAnswer)
+  }
+
+  test("sort should not be removed when there is a node which doesn't guarantee any order")
{
+    val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc)
+    val groupedAndResorted = orderedPlan.groupBy('a)(sum('a)).orderBy('a.asc)
+    val optimized = Optimize.execute(groupedAndResorted.analyze)
+    val correctAnswer = groupedAndResorted.analyze
+    comparePlans(optimized, correctAnswer)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/25892f3c/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index d68aeb2..a8794be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -99,7 +99,7 @@ class CacheManager extends Logging {
         sparkSession.sessionState.conf.columnBatchSize, storageLevel,
         sparkSession.sessionState.executePlan(planToCache).executedPlan,
         tableName,
-        planToCache.stats)
+        planToCache)
       cachedData.add(CachedData(planToCache, inMemoryRelation))
     }
   }
@@ -148,7 +148,7 @@ class CacheManager extends Logging {
           storageLevel = cd.cachedRepresentation.storageLevel,
           child = spark.sessionState.executePlan(cd.plan).executedPlan,
           tableName = cd.cachedRepresentation.tableName,
-          statsOfPlanToCache = cd.plan.stats)
+          logicalPlan = cd.plan)
         needToRecache += cd.copy(cachedRepresentation = newCache)
       }
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/25892f3c/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index f355550..be50a15 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -125,7 +125,7 @@ case class LogicalRDD(
     output: Seq[Attribute],
     rdd: RDD[InternalRow],
     outputPartitioning: Partitioning = UnknownPartitioning(0),
-    outputOrdering: Seq[SortOrder] = Nil,
+    override val outputOrdering: Seq[SortOrder] = Nil,
     override val isStreaming: Boolean = false)(session: SparkSession)
   extends LeafNode with MultiInstanceRelation {
 

http://git-wip-us.apache.org/repos/asf/spark/blob/25892f3c/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index 2579046..a7ba9b8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, LogicalPlan, Statistics}
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.LongAccumulator
@@ -39,9 +39,9 @@ object InMemoryRelation {
       storageLevel: StorageLevel,
       child: SparkPlan,
       tableName: Option[String],
-      statsOfPlanToCache: Statistics): InMemoryRelation =
+      logicalPlan: LogicalPlan): InMemoryRelation =
     new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)(
-      statsOfPlanToCache = statsOfPlanToCache)
+      statsOfPlanToCache = logicalPlan.stats, outputOrdering = logicalPlan.outputOrdering)
 }
 
 
@@ -64,7 +64,8 @@ case class InMemoryRelation(
     tableName: Option[String])(
     @transient var _cachedColumnBuffers: RDD[CachedBatch] = null,
     val sizeInBytesStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator,
-    statsOfPlanToCache: Statistics)
+    statsOfPlanToCache: Statistics,
+    override val outputOrdering: Seq[SortOrder])
   extends logical.LeafNode with MultiInstanceRelation {
 
   override protected def innerChildren: Seq[SparkPlan] = Seq(child)
@@ -76,7 +77,8 @@ case class InMemoryRelation(
       tableName = None)(
       _cachedColumnBuffers,
       sizeInBytesStats,
-      statsOfPlanToCache)
+      statsOfPlanToCache,
+      outputOrdering)
 
   override def producedAttributes: AttributeSet = outputSet
 
@@ -159,7 +161,7 @@ case class InMemoryRelation(
   def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = {
     InMemoryRelation(
       newOutput, useCompression, batchSize, storageLevel, child, tableName)(
-        _cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache)
+        _cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache, outputOrdering)
   }
 
   override def newInstance(): this.type = {
@@ -172,7 +174,8 @@ case class InMemoryRelation(
       tableName)(
         _cachedColumnBuffers,
         sizeInBytesStats,
-        statsOfPlanToCache).asInstanceOf[this.type]
+        statsOfPlanToCache,
+        outputOrdering).asInstanceOf[this.type]
   }
 
   def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers

http://git-wip-us.apache.org/repos/asf/spark/blob/25892f3c/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala
index cee85ec..949505e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala
@@ -39,7 +39,7 @@ class ConfigBehaviorSuite extends QueryTest with SharedSQLContext {
     def computeChiSquareTest(): Double = {
       val n = 10000
       // Trigger a sort
-      val data = spark.range(0, n, 1, 1).sort('id)
+      val data = spark.range(0, n, 1, 1).sort('id.desc)
         .selectExpr("SPARK_PARTITION_ID() pid", "id").as[(Int, Long)].collect()
 
       // Compute histogram for the number of records per partition post sort

http://git-wip-us.apache.org/repos/asf/spark/blob/25892f3c/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index f8b26f5..40915a1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.{execution, Row}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, RightOuter}
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition, Sort}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
 import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange,
ShuffleExchangeExec}
@@ -197,6 +197,19 @@ class PlannerSuite extends SharedSQLContext {
     assert(planned.child.isInstanceOf[CollectLimitExec])
   }
 
+  test("SPARK-23375: Cached sorted data doesn't need to be re-sorted") {
+    val query = testData.select('key, 'value).sort('key.desc).cache()
+    assert(query.queryExecution.optimizedPlan.isInstanceOf[InMemoryRelation])
+    val resorted = query.sort('key.desc)
+    assert(resorted.queryExecution.optimizedPlan.collect { case s: Sort => s}.isEmpty)
+    assert(resorted.select('key).collect().map(_.getInt(0)).toSeq ==
+      (1 to 100).reverse)
+    // with a different order, the sort is needed
+    val sortedAsc = query.sort('key)
+    assert(sortedAsc.queryExecution.optimizedPlan.collect { case s: Sort => s}.size ==
1)
+    assert(sortedAsc.select('key).collect().map(_.getInt(0)).toSeq == (1 to 100))
+  }
+
   test("PartitioningCollection") {
     withTempView("normal", "small", "tiny") {
       testData.createOrReplaceTempView("normal")

http://git-wip-us.apache.org/repos/asf/spark/blob/25892f3c/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
index 26b63e8..9b7b316 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
@@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp}
 
 import org.apache.spark.sql.{DataFrame, QueryTest, Row}
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, In}
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.execution.{FilterExec, LocalTableScanExec, WholeStageCodegenExec}
 import org.apache.spark.sql.functions._
@@ -42,7 +43,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext
{
     val storageLevel = MEMORY_ONLY
     val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan
     val inMemoryRelation = InMemoryRelation(useCompression = true, 5, storageLevel, plan,
None,
-      data.logicalPlan.stats)
+      data.logicalPlan)
 
     assert(inMemoryRelation.cachedColumnBuffers.getStorageLevel == storageLevel)
     inMemoryRelation.cachedColumnBuffers.collect().head match {
@@ -119,7 +120,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext
{
   test("simple columnar query") {
     val plan = spark.sessionState.executePlan(testData.logicalPlan).sparkPlan
     val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None,
-      testData.logicalPlan.stats)
+      testData.logicalPlan)
 
     checkAnswer(scan, testData.collect().toSeq)
   }
@@ -138,7 +139,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext
{
     val logicalPlan = testData.select('value, 'key).logicalPlan
     val plan = spark.sessionState.executePlan(logicalPlan).sparkPlan
     val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None,
-      logicalPlan.stats)
+      logicalPlan)
 
     checkAnswer(scan, testData.collect().map {
       case Row(key: Int, value: String) => value -> key
@@ -155,7 +156,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext
{
   test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times")
{
     val plan = spark.sessionState.executePlan(testData.logicalPlan).sparkPlan
     val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None,
-      testData.logicalPlan.stats)
+      testData.logicalPlan)
 
     checkAnswer(scan, testData.collect().toSeq)
     checkAnswer(scan, testData.collect().toSeq)
@@ -329,7 +330,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext
{
   test("SPARK-17549: cached table size should be correctly calculated") {
     val data = spark.sparkContext.parallelize(1 to 10, 5).toDF()
     val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan
-    val cached = InMemoryRelation(true, 5, MEMORY_ONLY, plan, None, data.logicalPlan.stats)
+    val cached = InMemoryRelation(true, 5, MEMORY_ONLY, plan, None, data.logicalPlan)
 
     // Materialize the data.
     val expectedAnswer = data.collect()
@@ -455,7 +456,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext
{
   test("SPARK-22249: buildFilter should not throw exception when In contains an empty list")
{
     val attribute = AttributeReference("a", IntegerType)()
     val localTableScanExec = LocalTableScanExec(Seq(attribute), Nil)
-    val testRelation = InMemoryRelation(false, 1, MEMORY_ONLY, localTableScanExec, None,
null)
+    val testRelation = InMemoryRelation(false, 1, MEMORY_ONLY, localTableScanExec, None,
+      LocalRelation(Seq(attribute), Nil))
     val tableScanExec = InMemoryTableScanExec(Seq(attribute),
       Seq(In(attribute, Nil)), testRelation)
     assert(tableScanExec.partitionFilters.isEmpty)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message