spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From l...@apache.org
Subject spark git commit: [SPARK-13549][SQL] Refactor the Optimizer Rule CollapseProject
Date Wed, 23 Mar 2016 16:51:46 GMT
Repository: spark
Updated Branches:
  refs/heads/master cde086cb2 -> 6ce008ba4


[SPARK-13549][SQL] Refactor the Optimizer Rule CollapseProject

#### What changes were proposed in this pull request?

The PR https://github.com/apache/spark/pull/10541 changed the rule `CollapseProject` to also
collapse a `Project` into an `Aggregate`, and left a to-do item to remove the duplicate code.
This PR completes that to-do item and adds test cases covering the change.

#### How was this patch tested?

Added new test cases to `CollapseProjectSuite`.

liancheng Could you check if the code refactoring is fine? Thanks!

Author: gatorsmile <gatorsmile@gmail.com>

Closes #11427 from gatorsmile/collapseProjectRefactor.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6ce008ba
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6ce008ba
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6ce008ba

Branch: refs/heads/master
Commit: 6ce008ba46aa1fc8a5c222ce0f25a6d81f53588e
Parents: cde086c
Author: gatorsmile <gatorsmile@gmail.com>
Authored: Thu Mar 24 00:51:31 2016 +0800
Committer: Cheng Lian <lian@databricks.com>
Committed: Thu Mar 24 00:51:31 2016 +0800

----------------------------------------------------------------------
 .../sql/catalyst/optimizer/Optimizer.scala      | 101 +++++++++----------
 .../optimizer/CollapseProjectSuite.scala        |  26 ++++-
 2 files changed, 70 insertions(+), 57 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6ce008ba/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 0840d46..4cfdcf9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -417,68 +417,57 @@ object ColumnPruning extends Rule[LogicalPlan] {
 object CollapseProject extends Rule[LogicalPlan] {
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
-    case p @ Project(projectList1, Project(projectList2, child)) =>
-      // Create a map of Aliases to their values from the child projection.
-      // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)).
-      val aliasMap = AttributeMap(projectList2.collect {
-        case a: Alias => (a.toAttribute, a)
-      })
-
-      // We only collapse these two Projects if their overlapped expressions are all
-      // deterministic.
-      val hasNondeterministic = projectList1.exists(_.collect {
-        case a: Attribute if aliasMap.contains(a) => aliasMap(a).child
-      }.exists(!_.deterministic))
-
-      if (hasNondeterministic) {
+    case p1 @ Project(_, p2: Project) =>
+      if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList)) {
+        p1
+      } else {
+        p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList))
+      }
+    case p @ Project(_, agg: Aggregate) =>
+      if (haveCommonNonDeterministicOutput(p.projectList, agg.aggregateExpressions)) {
         p
       } else {
-        // Substitute any attributes that are produced by the child projection, so that we safely
-        // eliminate it.
-        // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...'
-        // TODO: Fix TransformBase to avoid the cast below.
-        val substitutedProjection = projectList1.map(_.transform {
-          case a: Attribute => aliasMap.getOrElse(a, a)
-        }).asInstanceOf[Seq[NamedExpression]]
-        // collapse 2 projects may introduce unnecessary Aliases, trim them here.
-        val cleanedProjection = substitutedProjection.map(p =>
-          CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression]
-        )
-        Project(cleanedProjection, child)
+        agg.copy(aggregateExpressions = buildCleanedProjectList(
+          p.projectList, agg.aggregateExpressions))
       }
+  }
 
-    // TODO Eliminate duplicate code
-    // This clause is identical to the one above except that the inner operator is an `Aggregate`
-    // rather than a `Project`.
-    case p @ Project(projectList1, agg @ Aggregate(_, projectList2, child)) =>
-      // Create a map of Aliases to their values from the child projection.
-      // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)).
-      val aliasMap = AttributeMap(projectList2.collect {
-        case a: Alias => (a.toAttribute, a)
-      })
+  private def collectAliases(projectList: Seq[NamedExpression]): AttributeMap[Alias] = {
+    AttributeMap(projectList.collect {
+      case a: Alias => a.toAttribute -> a
+    })
+  }
 
-      // We only collapse these two Projects if their overlapped expressions are all
-      // deterministic.
-      val hasNondeterministic = projectList1.exists(_.collect {
-        case a: Attribute if aliasMap.contains(a) => aliasMap(a).child
-      }.exists(!_.deterministic))
+  private def haveCommonNonDeterministicOutput(
+      upper: Seq[NamedExpression], lower: Seq[NamedExpression]): Boolean = {
+    // Create a map of Aliases to their values from the lower projection.
+    // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)).
+    val aliases = collectAliases(lower)
+
+    // Collapse upper and lower Projects if and only if their overlapped expressions are all
+    // deterministic.
+    upper.exists(_.collect {
+      case a: Attribute if aliases.contains(a) => aliases(a).child
+    }.exists(!_.deterministic))
+  }
 
-      if (hasNondeterministic) {
-        p
-      } else {
-        // Substitute any attributes that are produced by the child projection, so that we safely
-        // eliminate it.
-        // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...'
-        // TODO: Fix TransformBase to avoid the cast below.
-        val substitutedProjection = projectList1.map(_.transform {
-          case a: Attribute => aliasMap.getOrElse(a, a)
-        }).asInstanceOf[Seq[NamedExpression]]
-        // collapse 2 projects may introduce unnecessary Aliases, trim them here.
-        val cleanedProjection = substitutedProjection.map(p =>
-          CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression]
-        )
-        agg.copy(aggregateExpressions = cleanedProjection)
-      }
+  private def buildCleanedProjectList(
+      upper: Seq[NamedExpression],
+      lower: Seq[NamedExpression]): Seq[NamedExpression] = {
+    // Create a map of Aliases to their values from the lower projection.
+    // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)).
+    val aliases = collectAliases(lower)
+
+    // Substitute any attributes that are produced by the lower projection, so that we safely
+    // eliminate it.
+    // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...'
+    val rewrittenUpper = upper.map(_.transform {
+      case a: Attribute => aliases.getOrElse(a, a)
+    })
+    // collapse upper and lower Projects may introduce unnecessary Aliases, trim them here.
+    rewrittenUpper.map { p =>
+      CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression]
+    }
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/6ce008ba/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
index 833f054..587437e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
@@ -29,7 +29,7 @@ class CollapseProjectSuite extends PlanTest {
   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches =
       Batch("Subqueries", FixedPoint(10), EliminateSubqueryAliases) ::
-        Batch("CollapseProject", Once, CollapseProject) :: Nil
+      Batch("CollapseProject", Once, CollapseProject) :: Nil
   }
 
   val testRelation = LocalRelation('a.int, 'b.int)
@@ -95,4 +95,28 @@ class CollapseProjectSuite extends PlanTest {
 
     comparePlans(optimized, correctAnswer)
   }
+
+  test("collapse project into aggregate") {
+    val query = testRelation
+      .groupBy('a, 'b)(('a + 1).as('a_plus_1), 'b)
+      .select('a_plus_1, ('b + 1).as('b_plus_1))
+
+    val optimized = Optimize.execute(query.analyze)
+
+    val correctAnswer = testRelation
+      .groupBy('a, 'b)(('a + 1).as('a_plus_1), ('b + 1).as('b_plus_1)).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("do not collapse common nondeterministic project and aggregate") {
+    val query = testRelation
+      .groupBy('a)('a, Rand(10).as('rand))
+      .select(('rand + 1).as('rand1), ('rand + 2).as('rand2))
+
+    val optimized = Optimize.execute(query.analyze)
+    val correctAnswer = query.analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message