spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From wenc...@apache.org
Subject spark git commit: [SPARK-13957][SQL] Support Group By Ordinal in SQL
Date Fri, 25 Mar 2016 04:56:15 GMT
Repository: spark
Updated Branches:
  refs/heads/master 0874ff3aa -> 05f652d6c


[SPARK-13957][SQL] Support Group By Ordinal in SQL

#### What changes were proposed in this pull request?
This PR adds support for group by position (ordinal) in SQL. For example, when users input
the following query:
```SQL
select c1 as a, c2, c3, sum(*) from tbl group by 1, 3, c4
```
The ordinals are recognized as the positions in the select list. Thus, `Analyzer` converts
it to
```SQL
select c1 as a, c2, c3, sum(*) from tbl group by c1, c3, c4
```

This is controlled by the config option `spark.sql.groupByOrdinal`.
- When true, the ordinal numbers in group by clauses are treated as the position in the select
list.
- When false, the ordinal numbers are ignored.
- Only integer literals are converted (not foldable expressions). If foldable expressions are
found, they are ignored.
- When a position specified in the group by clause corresponds to an aggregate function
in the select list, an exception is thrown with an explanatory message.
- A star (`*`) is not allowed in the select list when users specify ordinals in group by.

Note: This PR is taken from https://github.com/apache/spark/pull/10731. When merging this
PR, please give the credit to zhichao-li

Also cc all the people who are involved in the previous discussion:  rxin cloud-fan marmbrus
yhuai hvanhovell adrian-wang chenghao-intel tejasapatil

#### How was this patch tested?

Added a few test cases covering both positive and negative scenarios.

Author: gatorsmile <gatorsmile@gmail.com>
Author: xiaoli <lixiao1983@gmail.com>
Author: Xiao Li <xiaoli@Xiaos-MacBook-Pro.local>

Closes #11846 from gatorsmile/groupByOrdinal.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/05f652d6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/05f652d6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/05f652d6

Branch: refs/heads/master
Commit: 05f652d6c2bbd764a1dd5a45301811e14519486f
Parents: 0874ff3
Author: gatorsmile <gatorsmile@gmail.com>
Authored: Fri Mar 25 12:55:58 2016 +0800
Committer: Wenchen Fan <wenchen@databricks.com>
Committed: Fri Mar 25 12:55:58 2016 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/CatalystConf.scala       |  8 +-
 .../spark/sql/catalyst/analysis/Analyzer.scala  | 72 +++++++++++----
 .../spark/sql/catalyst/planning/patterns.scala  |  3 +-
 .../org/apache/spark/sql/internal/SQLConf.scala |  6 ++
 .../org/apache/spark/sql/SQLQuerySuite.scala    | 92 ++++++++++++++++++--
 5 files changed, 156 insertions(+), 25 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/05f652d6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
index e10ab97..d5ac015 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
@@ -23,6 +23,7 @@ private[spark] trait CatalystConf {
   def caseSensitiveAnalysis: Boolean
 
   def orderByOrdinal: Boolean
+  def groupByOrdinal: Boolean
 
   /**
    * Returns the [[Resolver]] for the current configuration, which can be used to determin
if two
@@ -48,11 +49,16 @@ object EmptyConf extends CatalystConf {
   override def orderByOrdinal: Boolean = {
     throw new UnsupportedOperationException
   }
+  override def groupByOrdinal: Boolean = {
+    throw new UnsupportedOperationException
+  }
 }
 
 /** A CatalystConf that can be used for local testing. */
 case class SimpleCatalystConf(
     caseSensitiveAnalysis: Boolean,
-    orderByOrdinal: Boolean = true)
+    orderByOrdinal: Boolean = true,
+    groupByOrdinal: Boolean = true)
+
   extends CatalystConf {
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/05f652d6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 07b0f5e..d0a31e7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -85,6 +85,7 @@ class Analyzer(
       ResolveGroupingAnalytics ::
       ResolvePivot ::
       ResolveUpCast ::
+      ResolveOrdinalInOrderByAndGroupBy ::
       ResolveSortReferences ::
       ResolveGenerate ::
       ResolveFunctions ::
@@ -385,7 +386,13 @@ class Analyzer(
         p.copy(projectList = buildExpandedProjectList(p.projectList, p.child))
       // If the aggregate function argument contains Stars, expand it.
       case a: Aggregate if containsStar(a.aggregateExpressions) =>
-        a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child))
+        if (conf.groupByOrdinal && a.groupingExpressions.exists(IntegerIndex.unapply(_).nonEmpty))
{
+          failAnalysis(
+            "Group by position: star is not allowed to use in the select list " +
+              "when using ordinals in group by")
+        } else {
+          a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions,
a.child))
+        }
       // If the script transformation input contains Stars, expand it.
       case t: ScriptTransformation if containsStar(t.input) =>
         t.copy(
@@ -634,21 +641,23 @@ class Analyzer(
     }
   }
 
-  /**
-   * In many dialects of SQL it is valid to sort by attributes that are not present in the
SELECT
-   * clause.  This rule detects such queries and adds the required attributes to the original
-   * projection, so that they will be available during sorting. Another projection is added
to
-   * remove these attributes after sorting.
-   *
-   * This rule also resolves the position number in sort references. This support is introduced
-   * in Spark 2.0. Before Spark 2.0, the integers in Order By has no effect on output sorting.
-   * - When the sort references are not integer but foldable expressions, ignore them.
-   * - When spark.sql.orderByOrdinal is set to false, ignore the position numbers too.
-   */
-  object ResolveSortReferences extends Rule[LogicalPlan] {
+ /**
+  * In many dialects of SQL it is valid to use ordinal positions in order/sort by and group
by
+  * clauses. This rule is to convert ordinal positions to the corresponding expressions in
the
+  * select list. This support is introduced in Spark 2.0.
+  *
+  * - When the sort references or group by expressions are not integer but foldable expressions,
+  * just ignore them.
+  * - When spark.sql.orderByOrdinal/spark.sql.groupByOrdinal is set to false, ignore the
position
+  * numbers too.
+  *
+  * Before the release of Spark 2.0, the literals in order/sort by and group by clauses
+  * have no effect on the results.
+  */
+  object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-      case s: Sort if !s.child.resolved => s
-      // Replace the index with the related attribute for ORDER BY
+      case p if !p.childrenResolved => p
+      // Replace the index with the related attribute for ORDER BY,
       // which is a 1-base position of the projection list.
       case s @ Sort(orders, global, child)
           if conf.orderByOrdinal && orders.exists(o => IntegerIndex.unapply(o.child).nonEmpty)
=>
@@ -665,10 +674,41 @@ class Analyzer(
         }
         Sort(newOrders, global, child)
 
+      // Replace the index with the corresponding expression in aggregateExpressions. The
index is
+      // a 1-base position of aggregateExpressions, which is output columns (select expression)
+      case a @ Aggregate(groups, aggs, child)
+          if conf.groupByOrdinal && aggs.forall(_.resolved) &&
+            groups.exists(IntegerIndex.unapply(_).nonEmpty) =>
+        val newGroups = groups.map {
+          case IntegerIndex(index) if index > 0 && index <= aggs.size =>
+            aggs(index - 1) match {
+              case e if ResolveAggregateFunctions.containsAggregate(e) =>
+                throw new UnresolvedException(a,
+                  s"Group by position: the '$index'th column in the select contains an "
+
+                  s"aggregate function: ${e.sql}. Aggregate functions are not allowed in
GROUP BY")
+              case o => o
+            }
+          case IntegerIndex(index) =>
+            throw new UnresolvedException(a,
+              s"Group by position: '$index' exceeds the size of the select list '${aggs.size}'.")
+          case o => o
+        }
+        Aggregate(newGroups, aggs, child)
+    }
+  }
+
+  /**
+   * In many dialects of SQL it is valid to sort by attributes that are not present in the
SELECT
+   * clause.  This rule detects such queries and adds the required attributes to the original
+   * projection, so that they will be available during sorting. Another projection is added
to
+   * remove these attributes after sorting.
+   */
+  object ResolveSortReferences extends Rule[LogicalPlan] {
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
       // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions
       case sa @ Sort(_, _, child: Aggregate) => sa
 
-      case s @ Sort(order, _, child) if !s.resolved =>
+      case s @ Sort(order, _, child) if !s.resolved && child.resolved =>
         try {
           val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder])
           val requiredAttrs = AttributeSet(newOrder).filter(_.resolved)

http://git-wip-us.apache.org/repos/asf/spark/blob/05f652d6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index ada8424..9c92707 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -210,7 +210,8 @@ object Unions {
 object IntegerIndex {
   def unapply(a: Any): Option[Int] = a match {
     case Literal(a: Int, IntegerType) => Some(a)
-    // When resolving ordinal in Sort, negative values are extracted for issuing error messages.
+    // When resolving ordinal in Sort and Group By, negative values are extracted
+    // for issuing error messages.
     case UnaryMinus(IntegerLiteral(v)) => Some(-v)
     case _ => None
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/05f652d6/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 863a876..77af0e0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -445,6 +445,11 @@ object SQLConf {
     doc = "When true, the ordinal numbers are treated as the position in the select list.
" +
           "When false, the ordinal numbers in order/sort By clause are ignored.")
 
+  val GROUP_BY_ORDINAL = booleanConf("spark.sql.groupByOrdinal",
+    defaultValue = Some(true),
+    doc = "When true, the ordinal numbers in group by clauses are treated as the position
" +
+      "in the select list. When false, the ordinal numbers are ignored.")
+
   // The output committer class used by HadoopFsRelation. The specified class needs to be
a
   // subclass of org.apache.hadoop.mapreduce.OutputCommitter.
   //
@@ -668,6 +673,7 @@ class SQLConf extends Serializable with CatalystConf with ParserConf with
Loggin
 
   override def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL)
 
+  override def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL)
   /** ********************** SQLConf functionality methods ************ */
 
   /** Set Spark SQL configuration properties. */

http://git-wip-us.apache.org/repos/asf/spark/blob/05f652d6/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index eb486a1..61358fd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -23,6 +23,7 @@ import java.sql.Timestamp
 import org.apache.spark.AccumulatorSuite
 import org.apache.spark.sql.catalyst.analysis.UnresolvedException
 import org.apache.spark.sql.catalyst.expressions.SortOrder
+import org.apache.spark.sql.catalyst.plans.logical.Aggregate
 import org.apache.spark.sql.execution.aggregate
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, CartesianProduct, SortMergeJoin}
 import org.apache.spark.sql.functions._
@@ -459,25 +460,103 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       Seq(Row(1, 3), Row(2, 3), Row(3, 3)))
   }
 
-  test("literal in agg grouping expressions") {
+  test("Group By Ordinal - basic") {
     checkAnswer(
-      sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"),
-      Seq(Row(1, 2), Row(2, 2), Row(3, 2)))
-    checkAnswer(
-      sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"),
-      Seq(Row(1, 2), Row(2, 2), Row(3, 2)))
+      sql("SELECT a, sum(b) FROM testData2 GROUP BY 1"),
+      sql("SELECT a, sum(b) FROM testData2 GROUP BY a"))
 
+    // duplicate group-by columns
     checkAnswer(
       sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"),
       sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a"))
+
+    checkAnswer(
+      sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY 1, 2"),
+      sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a"))
+  }
+
+  test("Group By Ordinal - non aggregate expressions") {
+    checkAnswer(
+      sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, 2"),
+      sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, b + 2"))
+
+    checkAnswer(
+      sql("SELECT a, b + 2 as c, count(2) FROM testData2 GROUP BY a, 2"),
+      sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, b + 2"))
+  }
+
+  test("Group By Ordinal - non-foldable constant expression") {
+    checkAnswer(
+      sql("SELECT a, b, sum(b) FROM testData2 GROUP BY a, b, 1 + 0"),
+      sql("SELECT a, b, sum(b) FROM testData2 GROUP BY a, b"))
+
     checkAnswer(
       sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"),
       sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a"))
+  }
+
+  test("Group By Ordinal - alias") {
+    checkAnswer(
+      sql("SELECT a, (b + 2) as c, count(2) FROM testData2 GROUP BY a, 2"),
+      sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, b + 2"))
+
+    checkAnswer(
+      sql("SELECT a as b, b as a, sum(b) FROM testData2 GROUP BY 1, 2"),
+      sql("SELECT a, b, sum(b) FROM testData2 GROUP BY a, b"))
+  }
+
+  test("Group By Ordinal - constants") {
     checkAnswer(
       sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"),
       sql("SELECT 1, 2, sum(b) FROM testData2"))
   }
 
+  test("Group By Ordinal - negative cases") {
+    intercept[UnresolvedException[Aggregate]] {
+      sql("SELECT a, b FROM testData2 GROUP BY -1")
+    }
+
+    intercept[UnresolvedException[Aggregate]] {
+      sql("SELECT a, b FROM testData2 GROUP BY 3")
+    }
+
+    var e = intercept[UnresolvedException[Aggregate]](
+      sql("SELECT SUM(a) FROM testData2 GROUP BY 1"))
+    assert(e.getMessage contains
+      "Invalid call to Group by position: the '1'th column in the select contains " +
+        "an aggregate function")
+
+    e = intercept[UnresolvedException[Aggregate]](
+      sql("SELECT SUM(a) + 1 FROM testData2 GROUP BY 1"))
+    assert(e.getMessage contains
+      "Invalid call to Group by position: the '1'th column in the select contains " +
+        "an aggregate function")
+
+    var ae = intercept[AnalysisException](
+      sql("SELECT a, rand(0), sum(b) FROM testData2 GROUP BY a, 2"))
+    assert(ae.getMessage contains
+      "nondeterministic expression rand(0) should not appear in grouping expression")
+
+    ae = intercept[AnalysisException](
+      sql("SELECT * FROM testData2 GROUP BY a, b, 1"))
+    assert(ae.getMessage contains
+      "Group by position: star is not allowed to use in the select list " +
+        "when using ordinals in group by")
+  }
+
+  test("Group By Ordinal: spark.sql.groupByOrdinal=false") {
+    withSQLConf(SQLConf.GROUP_BY_ORDINAL.key -> "false") {
+      // If spark.sql.groupByOrdinal=false, ignore the position number.
+      intercept[AnalysisException] {
+        sql("SELECT a, sum(b) FROM testData2 GROUP BY 1")
+      }
+      // '*' is not allowed to use in the select list when users specify ordinals in group
by
+      checkAnswer(
+        sql("SELECT * FROM testData2 GROUP BY a, b, 1"),
+        sql("SELECT * FROM testData2 GROUP BY a, b"))
+    }
+  }
+
   test("aggregates with nulls") {
     checkAnswer(
       sql("SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a)," +
@@ -2174,7 +2253,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     checkAnswer(
       sql("SELECT * FROM testData2 ORDER BY 1 + 0 DESC, b ASC"),
       sql("SELECT * FROM testData2 ORDER BY b ASC"))
-
     checkAnswer(
       sql("SELECT * FROM testData2 ORDER BY 1 DESC, b ASC"),
       sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC"))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message