spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From dav...@apache.org
Subject spark git commit: [SPARK-12740] [SPARK-13932] support grouping()/grouping_id() in having/order clause
Date Thu, 07 Apr 2016 18:51:40 GMT
Repository: spark
Updated Branches:
  refs/heads/master 8dcb0c7c9 -> aa852215f


[SPARK-12740] [SPARK-13932] support grouping()/grouping_id() in having/order clause

## What changes were proposed in this pull request?

This PR adds support for using grouping()/grouping_id() in HAVING/ORDER BY clauses.

The resolved grouping()/grouping_id() will be replaced by the unresolved "spark_grouping_id" virtual
attribute, which is then resolved by ResolveMissingReferences.

This PR also fixes the HAVING clause so that it can access a grouping column that is not present in
the SELECT clause, for example:
```sql
select count(1) from (select 1 as a) t group by a having a > 0
```
## How was this patch tested?

Add new tests.

Author: Davies Liu <davies@databricks.com>

Closes #12235 from davies/grouping_having.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/aa852215
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/aa852215
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/aa852215

Branch: refs/heads/master
Commit: aa852215f82876977d164f371627e894e86baacc
Parents: 8dcb0c7
Author: Davies Liu <davies@databricks.com>
Authored: Thu Apr 7 11:51:34 2016 -0700
Committer: Davies Liu <davies.liu@gmail.com>
Committed: Thu Apr 7 11:51:34 2016 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  | 181 +++++++++++++------
 .../catalyst/expressions/namedExpressions.scala |   4 +-
 .../org/apache/spark/sql/SQLQuerySuite.scala    |  82 +++++++++
 3 files changed, 211 insertions(+), 56 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/aa852215/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index bc8cf4e..7bcba42 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -87,7 +87,7 @@ class Analyzer(
       ResolveGroupingAnalytics ::
       ResolvePivot ::
       ResolveOrdinalInOrderByAndGroupBy ::
-      ResolveSortReferences ::
+      ResolveMissingReferences ::
       ResolveGenerate ::
       ResolveFunctions ::
       ResolveAliases ::
@@ -228,21 +228,56 @@ class Analyzer(
       Seq.tabulate(1 << c.groupByExprs.length)(i => i)
     }
 
-    private def hasGroupingId(expr: Seq[Expression]): Boolean = {
-      expr.exists(_.collectFirst {
-        case u: UnresolvedAttribute if resolver(u.name, VirtualColumn.groupingIdName) => u
-      }.isDefined)
+    private def hasGroupingAttribute(expr: Expression): Boolean = {
+      expr.collectFirst {
+        case u: UnresolvedAttribute if resolver(u.name, VirtualColumn.hiveGroupingIdName) => u
+      }.isDefined
     }
 
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+    private def hasGroupingFunction(e: Expression): Boolean = {
+      e.collectFirst {
+        case g: Grouping => g
+        case g: GroupingID => g
+      }.isDefined
+    }
+
+    private def replaceGroupingFunc(
+        expr: Expression,
+        groupByExprs: Seq[Expression],
+        gid: Expression): Expression = {
+      expr transform {
+        case e: GroupingID =>
+          if (e.groupByExprs.isEmpty || e.groupByExprs == groupByExprs) {
+            gid
+          } else {
+            throw new AnalysisException(
+              s"Columns of grouping_id (${e.groupByExprs.mkString(",")}) does not match " +
+                s"grouping columns (${groupByExprs.mkString(",")})")
+          }
+        case Grouping(col: Expression) =>
+          val idx = groupByExprs.indexOf(col)
+          if (idx >= 0) {
+            Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length - 1 - idx)),
+              Literal(1)), ByteType)
+          } else {
+            throw new AnalysisException(s"Column of grouping ($col) can't be found " +
+              s"in grouping columns ${groupByExprs.mkString(",")}")
+          }
+      }
+    }
+
+    // This require transformUp to replace grouping()/grouping_id() in resolved Filter/Sort
+    def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
       case a if !a.childrenResolved => a // be sure all of the children are resolved.
+      case p if p.expressions.exists(hasGroupingAttribute) =>
+        failAnalysis(
+          s"${VirtualColumn.hiveGroupingIdName} is deprecated; use grouping_id() instead")
+
       case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) =>
         GroupingSets(bitmasks(c), groupByExprs, child, aggregateExpressions)
       case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) =>
         GroupingSets(bitmasks(r), groupByExprs, child, aggregateExpressions)
-      case g: GroupingSets if g.expressions.exists(!_.resolved) && hasGroupingId(g.expressions) =>
-        failAnalysis(
-          s"${VirtualColumn.groupingIdName} is deprecated; use grouping_id() instead")
+
       // Ensure all the expressions have been resolved.
       case x: GroupingSets if x.expressions.forall(_.resolved) =>
         val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()
@@ -270,7 +305,7 @@ class Analyzer(
           def isPartOfAggregation(e: Expression): Boolean = {
             aggsBuffer.exists(a => a.find(_ eq e).isDefined)
           }
-          expr.transformDown {
+          replaceGroupingFunc(expr, x.groupByExprs, gid).transformDown {
             // AggregateExpression should be computed on the unmodified value of its argument
             // expressions, so we should not replace any references to grouping expression
             // inside it.
@@ -278,23 +313,6 @@ class Analyzer(
               aggsBuffer += e
               e
             case e if isPartOfAggregation(e) => e
-            case e: GroupingID =>
-              if (e.groupByExprs.isEmpty || e.groupByExprs == x.groupByExprs) {
-                gid
-              } else {
-                throw new AnalysisException(
-                  s"Columns of grouping_id (${e.groupByExprs.mkString(",")}) does not match " +
-                    s"grouping columns (${x.groupByExprs.mkString(",")})")
-              }
-            case Grouping(col: Expression) =>
-              val idx = x.groupByExprs.indexOf(col)
-              if (idx >= 0) {
-                Cast(BitwiseAnd(ShiftRight(gid, Literal(x.groupByExprs.length - 1 - idx)),
-                  Literal(1)), ByteType)
-              } else {
-                throw new AnalysisException(s"Column of grouping ($col) can't be found " +
-                  s"in grouping columns ${x.groupByExprs.mkString(",")}")
-              }
             case e =>
               val index = groupByAliases.indexWhere(_.child.semanticEquals(e))
               if (index == -1) {
@@ -306,9 +324,37 @@ class Analyzer(
         }
 
         Aggregate(
-          groupByAttributes :+ VirtualColumn.groupingIdAttribute,
+          groupByAttributes :+ gid,
           aggregations,
           Expand(x.bitmasks, groupByAliases, groupByAttributes, gid, x.child))
+
+      case f @ Filter(cond, child) if hasGroupingFunction(cond) =>
+        val groupingExprs = findGroupingExprs(child)
+        // The unresolved grouping id will be resolved by ResolveMissingReferences
+        val newCond = replaceGroupingFunc(cond, groupingExprs, VirtualColumn.groupingIdAttribute)
+        f.copy(condition = newCond)
+
+      case s @ Sort(order, _, child) if order.exists(hasGroupingFunction) =>
+        val groupingExprs = findGroupingExprs(child)
+        val gid = VirtualColumn.groupingIdAttribute
+        // The unresolved grouping id will be resolved by ResolveMissingReferences
+        val newOrder = order.map(replaceGroupingFunc(_, groupingExprs, gid).asInstanceOf[SortOrder])
+        s.copy(order = newOrder)
+    }
+
+    private def findGroupingExprs(plan: LogicalPlan): Seq[Expression] = {
+      plan.collectFirst {
+        case a: Aggregate =>
+          // this Aggregate should have grouping id as the last grouping key.
+          val gid = a.groupingExpressions.last
+          if (!gid.isInstanceOf[AttributeReference]
+            || gid.asInstanceOf[AttributeReference].name != VirtualColumn.groupingIdName) {
+            failAnalysis(s"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
+          }
+          a.groupingExpressions.take(a.groupingExpressions.length - 1)
+      }.getOrElse {
+        failAnalysis(s"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
+      }
     }
   }
 
@@ -663,13 +709,15 @@ class Analyzer(
    * clause.  This rule detects such queries and adds the required attributes to the original
    * projection, so that they will be available during sorting. Another projection is added to
    * remove these attributes after sorting.
+   *
+   * The HAVING clause could also used a grouping columns that is not presented in the SELECT.
    */
-  object ResolveSortReferences extends Rule[LogicalPlan] {
+  object ResolveMissingReferences extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
       // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions
       case sa @ Sort(_, _, child: Aggregate) => sa
 
-      case s @ Sort(order, _, child) if !s.resolved && child.resolved =>
+      case s @ Sort(order, _, child) if child.resolved =>
         try {
           val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder])
           val requiredAttrs = AttributeSet(newOrder).filter(_.resolved)
@@ -689,6 +737,26 @@ class Analyzer(
           // in Sort
           case ae: AnalysisException => s
         }
+
+      case f @ Filter(cond, child) if child.resolved =>
+        try {
+          val newCond = resolveExpressionRecursively(cond, child)
+          val requiredAttrs = newCond.references.filter(_.resolved)
+          val missingAttrs = requiredAttrs -- child.outputSet
+          if (missingAttrs.nonEmpty) {
+            // Add missing attributes and then project them away.
+            Project(child.output,
+              Filter(newCond, addMissingAttr(child, missingAttrs)))
+          } else if (newCond != cond) {
+            f.copy(condition = newCond)
+          } else {
+            f
+          }
+        } catch {
+          // Attempting to resolve it might fail. When this happens, return the original plan.
+          // Users will see an AnalysisException for resolution failure of missing attributes
+          case ae: AnalysisException => f
+        }
     }
 
     /**
@@ -843,27 +911,33 @@ class Analyzer(
           if aggregate.resolved =>
 
         // Try resolving the condition of the filter as though it is in the aggregate clause
-        val aggregatedCondition =
-          Aggregate(
-            grouping,
-            Alias(havingCondition, "havingCondition")(isGenerated = true) :: Nil,
-            child)
-        val resolvedOperator = execute(aggregatedCondition)
-        def resolvedAggregateFilter =
-          resolvedOperator
-            .asInstanceOf[Aggregate]
-            .aggregateExpressions.head
-
-        // If resolution was successful and we see the filter has an aggregate in it, add it to
-        // the original aggregate operator.
-        if (resolvedOperator.resolved && containsAggregate(resolvedAggregateFilter)) {
-          val aggExprsWithHaving = resolvedAggregateFilter +: originalAggExprs
-
-          Project(aggregate.output,
-            Filter(resolvedAggregateFilter.toAttribute,
-              aggregate.copy(aggregateExpressions = aggExprsWithHaving)))
-        } else {
-          filter
+        try {
+          val aggregatedCondition =
+            Aggregate(
+              grouping,
+              Alias(havingCondition, "havingCondition")(isGenerated = true) :: Nil,
+              child)
+          val resolvedOperator = execute(aggregatedCondition)
+          def resolvedAggregateFilter =
+            resolvedOperator
+              .asInstanceOf[Aggregate]
+              .aggregateExpressions.head
+
+          // If resolution was successful and we see the filter has an aggregate in it, add it to
+          // the original aggregate operator.
+          if (resolvedOperator.resolved && containsAggregate(resolvedAggregateFilter)) {
+            val aggExprsWithHaving = resolvedAggregateFilter +: originalAggExprs
+
+            Project(aggregate.output,
+              Filter(resolvedAggregateFilter.toAttribute,
+                aggregate.copy(aggregateExpressions = aggExprsWithHaving)))
+          } else {
+            filter
+          }
+        } catch {
+          // Attempting to resolve in the aggregate can result in ambiguity.  When this happens,
+          // just return the original plan.
+          case ae: AnalysisException => filter
         }
 
       case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved =>
@@ -927,11 +1001,8 @@ class Analyzer(
         }
     }
 
-    private def isAggregateExpression(e: Expression): Boolean = {
-      e.isInstanceOf[AggregateExpression] || e.isInstanceOf[Grouping] || e.isInstanceOf[GroupingID]
-    }
     def containsAggregate(condition: Expression): Boolean = {
-      condition.find(isAggregateExpression).isDefined
+      condition.find(_.isInstanceOf[AggregateExpression]).isDefined
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/aa852215/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 2307122..78310fb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -333,6 +333,8 @@ case class PrettyAttribute(
 }
 
 object VirtualColumn {
-  val groupingIdName: String = "grouping__id"
+  // The attribute name used by Hive, which has different result than Spark, deprecated.
+  val hiveGroupingIdName: String = "grouping__id"
+  val groupingIdName: String = "spark_grouping_id"
   val groupingIdAttribute: UnresolvedAttribute = UnresolvedAttribute(groupingIdName)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/aa852215/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 2ab7c15..dd648cd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2230,6 +2230,88 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     assert(error.getMessage contains "grouping__id is deprecated; use grouping_id() instead")
   }
 
+  test("grouping and grouping_id in having") {
+    checkAnswer(
+      sql("select course, year from courseSales group by cube(course, year)" +
+        " having grouping(year) = 1 and grouping_id(course, year) > 0"),
+        Row("Java", null) ::
+        Row("dotNET", null) ::
+        Row(null, null) :: Nil
+    )
+
+    var error = intercept[AnalysisException] {
+      sql("select course, year from courseSales group by course, year" +
+        " having grouping(course) > 0")
+    }
+    assert(error.getMessage contains
+      "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
+    error = intercept[AnalysisException] {
+      sql("select course, year from courseSales group by course, year" +
+        " having grouping_id(course, year) > 0")
+    }
+    assert(error.getMessage contains
+      "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
+    error = intercept[AnalysisException] {
+      sql("select course, year from courseSales group by cube(course, year)" +
+        " having grouping__id > 0")
+    }
+    assert(error.getMessage contains "grouping__id is deprecated; use grouping_id() instead")
+  }
+
+  test("grouping and grouping_id in sort") {
+    checkAnswer(
+      sql("select course, year, grouping(course), grouping(year) from courseSales" +
+        " group by cube(course, year) order by grouping_id(course, year), course, year"),
+      Row("Java", 2012, 0, 0) ::
+        Row("Java", 2013, 0, 0) ::
+        Row("dotNET", 2012, 0, 0) ::
+        Row("dotNET", 2013, 0, 0) ::
+        Row("Java", null, 0, 1) ::
+        Row("dotNET", null, 0, 1) ::
+        Row(null, 2012, 1, 0) ::
+        Row(null, 2013, 1, 0) ::
+        Row(null, null, 1, 1) :: Nil
+    )
+
+    checkAnswer(
+      sql("select course, year, grouping_id(course, year) from courseSales" +
+        " group by cube(course, year) order by grouping(course), grouping(year), course, year"),
+      Row("Java", 2012, 0) ::
+        Row("Java", 2013, 0) ::
+        Row("dotNET", 2012, 0) ::
+        Row("dotNET", 2013, 0) ::
+        Row("Java", null, 1) ::
+        Row("dotNET", null, 1) ::
+        Row(null, 2012, 2) ::
+        Row(null, 2013, 2) ::
+        Row(null, null, 3) :: Nil
+    )
+
+    var error = intercept[AnalysisException] {
+      sql("select course, year from courseSales group by course, year" +
+        " order by grouping(course)")
+    }
+    assert(error.getMessage contains
+      "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
+    error = intercept[AnalysisException] {
+      sql("select course, year from courseSales group by course, year" +
+        " order by grouping_id(course, year)")
+    }
+    assert(error.getMessage contains
+      "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
+    error = intercept[AnalysisException] {
+      sql("select course, year from courseSales group by cube(course, year)" +
+        " order by grouping__id")
+    }
+    assert(error.getMessage contains "grouping__id is deprecated; use grouping_id() instead")
+  }
+
+  test("filter on a grouping column that is not presented in SELECT") {
+    checkAnswer(
+      sql("select count(1) from (select 1 as a) t group by a having a > 0"),
+      Row(1) :: Nil)
+  }
+
   test("SPARK-13056: Null in map value causes NPE") {
     val df = Seq(1 -> Map("abc" -> "somestring", "cba" -> null)).toDF("key", "value")
     withTempTable("maptest") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message