spark-commits mailing list archives

From: dav...@apache.org
Subject: spark git commit: [SPARK-14600] [SQL] Push predicates through Expand
Date: Wed, 20 Apr 2016 04:55:03 GMT
Repository: spark
Updated Branches:
  refs/heads/master 85d759ca3 -> 856bc465d


[SPARK-14600] [SQL] Push predicates through Expand

## What changes were proposed in this pull request?

https://issues.apache.org/jira/browse/SPARK-14600

This PR makes `Expand.output` use attributes that are distinct from the grouping attributes
produced by the underlying `Project`, since the two have different meanings. This lets us
safely push filters down through `Expand`.
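
To illustrate the idea, here is a minimal, self-contained sketch (plain Scala, not Spark's actual `PushDownPredicate` code; `Attr`, `canPushThroughExpand`, and the attribute ids are invented for this example): a predicate can only move below `Expand` if every attribute it references is produced by `Expand`'s child, so giving the expanded grouping attributes fresh instances keeps filters on possibly-nulled grouping columns from being pushed incorrectly.

```scala
// Hypothetical model of attribute identity: each attribute instance has a unique id,
// mirroring how `newInstance` gives the expanded grouping attributes fresh ids.
final case class Attr(name: String, id: Long)

object PushThroughExpandSketch {
  // A predicate may move below Expand only if all attributes it references
  // are produced by Expand's child.
  def canPushThroughExpand(predicateRefs: Set[Attr], expandChildOutput: Set[Attr]): Boolean =
    predicateRefs.subsetOf(expandChildOutput)

  def main(args: Array[String]): Unit = {
    val a = Attr("a", 1); val b = Attr("b", 2); val c = Attr("c", 3)
    // After this patch, Expand's grouping output uses *new instances* of a and b:
    val aExpanded = Attr("a", 4); val bExpanded = Attr("b", 5)
    val expandChildOutput = Set(a, b, c)

    // A filter on an original child column can safely move below Expand...
    assert(canPushThroughExpand(Set(c), expandChildOutput))
    // ...but a filter on the expanded grouping attributes stays above Expand,
    // since those attributes may be replaced by null in some projections.
    assert(!canPushThroughExpand(Set(aExpanded, bExpanded), expandChildOutput))
  }
}
```

Before this change, the grouping attributes in `Expand.output` were the very same attributes produced by the `Project`, so the optimizer had to block all pushdown through `Expand` (the TODO case removed in Optimizer.scala below).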

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #12496 from cloud-fan/expand.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/856bc465
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/856bc465
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/856bc465

Branch: refs/heads/master
Commit: 856bc465d53ccfdfda75c82c85d7f318a5158088
Parents: 85d759c
Author: Wenchen Fan <wenchen@databricks.com>
Authored: Tue Apr 19 21:53:19 2016 -0700
Committer: Davies Liu <davies.liu@gmail.com>
Committed: Tue Apr 19 21:53:19 2016 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala       | 12 ++++++------
 .../spark/sql/catalyst/optimizer/Optimizer.scala     |  2 --
 .../sql/catalyst/plans/logical/basicOperators.scala  |  5 ++++-
 .../sql/catalyst/optimizer/FilterPushdownSuite.scala | 15 +++++++++++++++
 .../scala/org/apache/spark/sql/hive/SQLBuilder.scala | 14 +++++++++-----
 5 files changed, 34 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/856bc465/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 2364769..8595762 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -296,10 +296,13 @@ class Analyzer(
 
         val nonNullBitmask = x.bitmasks.reduce(_ & _)
 
-        val groupByAttributes = groupByAliases.zipWithIndex.map { case (a, idx) =>
+        val expandedAttributes = groupByAliases.zipWithIndex.map { case (a, idx) =>
           a.toAttribute.withNullability((nonNullBitmask & 1 << idx) == 0)
         }
 
+        val expand = Expand(x.bitmasks, groupByAliases, expandedAttributes, gid, x.child)
+        val groupingAttrs = expand.output.drop(x.child.output.length)
+
         val aggregations: Seq[NamedExpression] = x.aggregations.map { case expr =>
          // collect all the found AggregateExpression, so we can check an expression is part of
          // any AggregateExpression or not.
@@ -321,15 +324,12 @@ class Analyzer(
               if (index == -1) {
                 e
               } else {
-                groupByAttributes(index)
+                groupingAttrs(index)
               }
           }.asInstanceOf[NamedExpression]
         }
 
-        Aggregate(
-          groupByAttributes :+ gid,
-          aggregations,
-          Expand(x.bitmasks, groupByAliases, groupByAttributes, gid, x.child))
+        Aggregate(groupingAttrs, aggregations, expand)
 
       case f @ Filter(cond, child) if hasGroupingFunction(cond) =>
         val groupingExprs = findGroupingExprs(child)

http://git-wip-us.apache.org/repos/asf/spark/blob/856bc465/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index ecc2d77..e6d5545 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1020,8 +1020,6 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
     case filter @ Filter(_, f: Filter) => filter
     // should not push predicates through sample, or will generate different results.
     case filter @ Filter(_, s: Sample) => filter
-    // TODO: push predicates through expand
-    case filter @ Filter(_, e: Expand) => filter
 
    case filter @ Filter(condition, u: UnaryNode) if u.expressions.forall(_.deterministic) =>
       pushDownPredicate(filter, u.child) { predicate =>

http://git-wip-us.apache.org/repos/asf/spark/blob/856bc465/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index d4fc9e4..a445ce6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -516,7 +516,10 @@ private[sql] object Expand {
      // groupingId is the last output, here we use the bit mask as the concrete value for it.
       } :+ Literal.create(bitmask, IntegerType)
     }
-    val output = child.output ++ groupByAttrs :+ gid
+
+    // the `groupByAttrs` has different meaning in `Expand.output`, it could be the original
+    // grouping expression or null, so here we create new instance of it.
+    val output = child.output ++ groupByAttrs.map(_.newInstance) :+ gid
     Expand(projections, output, Project(child.output ++ groupByAliases, child))
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/856bc465/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index df7529d..9174b4e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -743,4 +743,19 @@ class FilterPushdownSuite extends PlanTest {
 
     comparePlans(optimized, correctAnswer)
   }
+
+  test("expand") {
+    val agg = testRelation
+      .groupBy(Cube(Seq('a, 'b)))('a, 'b, sum('c))
+      .analyze
+      .asInstanceOf[Aggregate]
+
+    val a = agg.output(0)
+    val b = agg.output(1)
+
+    val query = agg.where(a > 1 && b > 2)
+    val optimized = Optimize.execute(query)
+    val correctedAnswer = agg.copy(child = agg.child.where(a > 1 && b > 2)).analyze
+    comparePlans(optimized, correctedAnswer)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/856bc465/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
index e54358e..2d44813 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
@@ -288,8 +288,9 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging
 
   private def isGroupingSet(a: Aggregate, e: Expand, p: Project): Boolean = {
     assert(a.child == e && e.child == p)
-    a.groupingExpressions.forall(_.isInstanceOf[Attribute]) &&
-      sameOutput(e.output, p.child.output ++ a.groupingExpressions.map(_.asInstanceOf[Attribute]))
+    a.groupingExpressions.forall(_.isInstanceOf[Attribute]) && sameOutput(
+      e.output.drop(p.child.output.length),
+      a.groupingExpressions.map(_.asInstanceOf[Attribute]))
   }
 
   private def groupingSetToSQL(
@@ -303,25 +304,28 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging
 
     val numOriginalOutput = project.child.output.length
     // Assumption: Aggregate's groupingExpressions is composed of
-    // 1) the attributes of aliased group by expressions
+    // 1) the grouping attributes
     // 2) gid, which is always the last one
     val groupByAttributes = agg.groupingExpressions.dropRight(1).map(_.asInstanceOf[Attribute])
     // Assumption: Project's projectList is composed of
     // 1) the original output (Project's child.output),
     // 2) the aliased group by expressions.
+    val expandedAttributes = project.output.drop(numOriginalOutput)
     val groupByExprs = project.projectList.drop(numOriginalOutput).map(_.asInstanceOf[Alias].child)
     val groupingSQL = groupByExprs.map(_.sql).mkString(", ")
 
     // a map from group by attributes to the original group by expressions.
     val groupByAttrMap = AttributeMap(groupByAttributes.zip(groupByExprs))
+    // a map from expanded attributes to the original group by expressions.
+    val expandedAttrMap = AttributeMap(expandedAttributes.zip(groupByExprs))
 
     val groupingSet: Seq[Seq[Expression]] = expand.projections.map { project =>
       // Assumption: expand.projections is composed of
       // 1) the original output (Project's child.output),
-      // 2) group by attributes(or null literal)
+      // 2) expanded attributes(or null literal)
       // 3) gid, which is always the last one in each project in Expand
       project.drop(numOriginalOutput).dropRight(1).collect {
-        case attr: Attribute if groupByAttrMap.contains(attr) => groupByAttrMap(attr)
+        case attr: Attribute if expandedAttrMap.contains(attr) => expandedAttrMap(attr)
       }
     }
     val groupingSetSQL = "GROUPING SETS(" +



