spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From wenc...@apache.org
Subject spark git commit: [SPARK-21896][SQL] Fix StackOverflow caused by window functions inside aggregate functions
Date Mon, 04 Jun 2018 20:27:53 GMT
Repository: spark
Updated Branches:
  refs/heads/master 0be5aa274 -> 7297ae04d


[SPARK-21896][SQL] Fix StackOverflow caused by window functions inside aggregate functions

## What changes were proposed in this pull request?

This PR explicitly prohibits window functions inside aggregate functions. Currently, such a
query causes a stack overflow during analysis. See PR #19193 for previous discussion.

## How was this patch tested?

This PR comes with a dedicated unit test.

Author: aokolnychyi <anton.okolnychyi@sap.com>

Closes #21473 from aokolnychyi/fix-stackoverflow-window-funcs.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7297ae04
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7297ae04
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7297ae04

Branch: refs/heads/master
Commit: 7297ae04d87b6e3d48b747a7c1d53687fcc3971c
Parents: 0be5aa2
Author: aokolnychyi <anton.okolnychyi@sap.com>
Authored: Mon Jun 4 13:28:16 2018 -0700
Committer: Wenchen Fan <wenchen@databricks.com>
Committed: Mon Jun 4 13:28:16 2018 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  | 10 ++++--
 .../spark/sql/DataFrameAggregateSuite.scala     | 34 ++++++++++++++++++--
 2 files changed, 39 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7297ae04/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 3eaa9ec..f9947d1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1744,10 +1744,10 @@ class Analyzer(
    *    it into the plan tree.
    */
   object ExtractWindowExpressions extends Rule[LogicalPlan] {
-    private def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean =
-      projectList.exists(hasWindowFunction)
+    private def hasWindowFunction(exprs: Seq[Expression]): Boolean =
+      exprs.exists(hasWindowFunction)
 
-    private def hasWindowFunction(expr: NamedExpression): Boolean = {
+    private def hasWindowFunction(expr: Expression): Boolean = {
       expr.find {
         case window: WindowExpression => true
         case _ => false
@@ -1830,6 +1830,10 @@ class Analyzer(
             seenWindowAggregates += newAgg
             WindowExpression(newAgg, spec)
 
+          case AggregateExpression(aggFunc, _, _, _) if hasWindowFunction(aggFunc.children) =>
+            failAnalysis("It is not allowed to use a window function inside an aggregate " +
+              "function. Please use the inner window function in a sub-query.")
+
           // Extracts AggregateExpression. For example, for SUM(x) - Sum(y) OVER (...),
           // we need to extract SUM(x).
           case agg: AggregateExpression if !seenWindowAggregates.contains(agg) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/7297ae04/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 96c2896..f495a94 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -19,8 +19,8 @@ package org.apache.spark.sql
 
 import scala.util.Random
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
-import org.apache.spark.sql.catalyst.expressions.aggregate.Count
+import org.scalatest.Matchers.the
+
 import org.apache.spark.sql.execution.WholeStageCodegenExec
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
@@ -687,4 +687,34 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
       }
     }
   }
+
+  test("SPARK-21896: Window functions inside aggregate functions") {
+    def checkWindowError(df: => DataFrame): Unit = {
+      val thrownException = the [AnalysisException] thrownBy {
+        df.queryExecution.analyzed
+      }
+      assert(thrownException.message.contains("not allowed to use a window function"))
+    }
+
+    checkWindowError(testData2.select(min(avg('b).over(Window.partitionBy('a)))))
+    checkWindowError(testData2.agg(sum('b), max(rank().over(Window.orderBy('a)))))
+    checkWindowError(testData2.groupBy('a).agg(sum('b), max(rank().over(Window.orderBy('b)))))
+    checkWindowError(testData2.groupBy('a).agg(max(sum(sum('b)).over(Window.orderBy('a)))))
+    checkWindowError(
+      testData2.groupBy('a).agg(sum('b).as("s"), max(count("*").over())).where('s === 3))
+    checkAnswer(
+      testData2.groupBy('a).agg(max('b), sum('b).as("s"), count("*").over()).where('s === 3),
+      Row(1, 2, 3, 3) :: Row(2, 2, 3, 3) :: Row(3, 2, 3, 3) :: Nil)
+
+    checkWindowError(sql("SELECT MIN(AVG(b) OVER(PARTITION BY a)) FROM testData2"))
+    checkWindowError(sql("SELECT SUM(b), MAX(RANK() OVER(ORDER BY a)) FROM testData2"))
+    checkWindowError(sql("SELECT SUM(b), MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a"))
+    checkWindowError(sql("SELECT MAX(SUM(SUM(b)) OVER(ORDER BY a)) FROM testData2 GROUP BY a"))
+    checkWindowError(
+      sql("SELECT MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a HAVING SUM(b) = 3"))
+    checkAnswer(
+      sql("SELECT a, MAX(b), RANK() OVER(ORDER BY a) FROM testData2 GROUP BY a HAVING SUM(b) = 3"),
+      Row(1, 2, 1) :: Row(2, 2, 2) :: Row(3, 2, 3) :: Nil)
+  }
+
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message