spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From marmb...@apache.org
Subject git commit: [SPARK-2066][SQL] Adds checks for non-aggregate attributes with aggregation
Date Mon, 13 Oct 2014 20:36:47 GMT
Repository: spark
Updated Branches:
  refs/heads/master 2ac40da3f -> 56102dc2d


[SPARK-2066][SQL] Adds checks for non-aggregate attributes with aggregation

This PR adds a new rule `CheckAggregation` to the analyzer to provide a better error message
for non-aggregate attributes with aggregation.

Author: Cheng Lian <lian.cs.zju@gmail.com>

Closes #2774 from liancheng/non-aggregate-attr and squashes the following commits:

5246004 [Cheng Lian] Passes test suites
bf1878d [Cheng Lian] Adds checks for non-aggregate attributes with aggregation


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/56102dc2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/56102dc2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/56102dc2

Branch: refs/heads/master
Commit: 56102dc2d849c221f325a7888cd66abb640ec887
Parents: 2ac40da
Author: Cheng Lian <lian.cs.zju@gmail.com>
Authored: Mon Oct 13 13:36:39 2014 -0700
Committer: Michael Armbrust <michael@databricks.com>
Committed: Mon Oct 13 13:36:39 2014 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  | 36 +++++++++++++++++---
 .../org/apache/spark/sql/SQLQuerySuite.scala    | 26 ++++++++++++++
 2 files changed, 57 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/56102dc2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index fe83eb1..8255306 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -63,7 +63,8 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
       typeCoercionRules ++
       extendedRules : _*),
     Batch("Check Analysis", Once,
-      CheckResolution),
+      CheckResolution,
+      CheckAggregation),
     Batch("AnalysisOperators", fixedPoint,
       EliminateAnalysisOperators)
   )
@@ -89,6 +90,32 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
   }
 
   /**
+   * Checks for non-aggregated attributes with aggregation
+   */
+  object CheckAggregation extends Rule[LogicalPlan] {
+    def apply(plan: LogicalPlan): LogicalPlan = {
+      plan.transform {
+        case aggregatePlan @ Aggregate(groupingExprs, aggregateExprs, child) =>
+          def isValidAggregateExpression(expr: Expression): Boolean = expr match {
+            case _: AggregateExpression => true
+            case e: Attribute => groupingExprs.contains(e)
+            case e if groupingExprs.contains(e) => true
+            case e if e.references.isEmpty => true
+            case e => e.children.forall(isValidAggregateExpression)
+          }
+
+          aggregateExprs.foreach { e =>
+            if (!isValidAggregateExpression(e)) {
+              throw new TreeNodeException(plan, s"Expression not in GROUP BY: $e")
+            }
+          }
+
+          aggregatePlan
+      }
+    }
+  }
+
+  /**
    * Replaces [[UnresolvedRelation]]s with concrete relations from the catalog.
    */
   object ResolveRelations extends Rule[LogicalPlan] {
@@ -204,18 +231,17 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
    */
   object UnresolvedHavingClauseAttributes extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
-      case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _))

+      case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _))
           if aggregate.resolved && containsAggregate(havingCondition) => {
         val evaluatedCondition = Alias(havingCondition,  "havingCondition")()
         val aggExprsWithHaving = evaluatedCondition +: originalAggExprs
-        
+
         Project(aggregate.output,
           Filter(evaluatedCondition.toAttribute,
             aggregate.copy(aggregateExpressions = aggExprsWithHaving)))
       }
-      
     }
-    
+
     protected def containsAggregate(condition: Expression): Boolean =
       condition
         .collect { case ae: AggregateExpression => ae }

http://git-wip-us.apache.org/repos/asf/spark/blob/56102dc2/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index a94022c..15f6ba4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
 
 import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.joins.BroadcastHashJoin
 import org.apache.spark.sql.test._
 import org.scalatest.BeforeAndAfterAll
@@ -694,4 +695,29 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
     checkAnswer(
       sql("SELECT CASE WHEN key = 1 THEN 1 ELSE 2 END FROM testData WHERE key = 1 group by key"), 1)
   }
+
+  test("throw errors for non-aggregate attributes with aggregation") {
+    def checkAggregation(query: String, isInvalidQuery: Boolean = true) {
+      val logicalPlan = sql(query).queryExecution.logical
+
+      if (isInvalidQuery) {
+        val e = intercept[TreeNodeException[LogicalPlan]](sql(query).queryExecution.analyzed)
+        assert(
+          e.getMessage.startsWith("Expression not in GROUP BY"),
+          "Non-aggregate attribute(s) not detected\n" + logicalPlan)
+      } else {
+        // Should not throw
+        sql(query).queryExecution.analyzed
+      }
+    }
+
+    checkAggregation("SELECT key, COUNT(*) FROM testData")
+    checkAggregation("SELECT COUNT(key), COUNT(*) FROM testData", false)
+
+    checkAggregation("SELECT value, COUNT(*) FROM testData GROUP BY key")
+    checkAggregation("SELECT COUNT(value), SUM(key) FROM testData GROUP BY key", false)
+
+    checkAggregation("SELECT key + 2, COUNT(*) FROM testData GROUP BY key + 1")
+    checkAggregation("SELECT key + 1 + 1, COUNT(*) FROM testData GROUP BY key + 1", false)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message