spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From wenc...@apache.org
Subject spark git commit: [SPARK-18969][SQL] Support grouping by nondeterministic expressions
Date Thu, 12 Jan 2017 12:24:36 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-2.1 9b9867ef5 -> 616a78a56


[SPARK-18969][SQL] Support grouping by nondeterministic expressions

## What changes were proposed in this pull request?

Currently nondeterministic expressions are allowed in `Aggregate` (see the [comment](https://github.com/apache/spark/blob/v2.0.2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala#L249-L251)),
but the `PullOutNondeterministic` analyzer rule failed to handle `Aggregate`; this PR fixes
it.

close https://github.com/apache/spark/pull/16379

There is still one remaining issue: `SELECT a + rand() FROM t GROUP BY a + rand()` is not
allowed, because the two `rand()` expressions are different (we generate a random seed as the
default seed for `rand()`). https://issues.apache.org/jira/browse/SPARK-19035 is tracking this issue.

## How was this patch tested?

a new test suite

Author: Wenchen Fan <wenchen@databricks.com>

Closes #16404 from cloud-fan/groupby.

(cherry picked from commit 871d266649ddfed38c64dfda7158d8bb58d4b979)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/616a78a5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/616a78a5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/616a78a5

Branch: refs/heads/branch-2.1
Commit: 616a78a56cc911953e3133e60ab8c5a4fc287539
Parents: 9b9867e
Author: Wenchen Fan <wenchen@databricks.com>
Authored: Thu Jan 12 20:21:04 2017 +0800
Committer: Wenchen Fan <wenchen@databricks.com>
Committed: Thu Jan 12 20:24:23 2017 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  | 37 ++++++++-----
 .../analysis/PullOutNondeterministicSuite.scala | 56 ++++++++++++++++++++
 .../sql-tests/results/group-by-ordinal.sql.out  | 10 ++--
 3 files changed, 86 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/616a78a5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index f17c372..ab9de02 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1859,28 +1859,37 @@ class Analyzer(
       case p: Project => p
       case f: Filter => f
 
+      case a: Aggregate if a.groupingExpressions.exists(!_.deterministic) =>
+        val nondeterToAttr = getNondeterToAttr(a.groupingExpressions)
+        val newChild = Project(a.child.output ++ nondeterToAttr.values, a.child)
+        a.transformExpressions { case e =>
+          nondeterToAttr.get(e).map(_.toAttribute).getOrElse(e)
+        }.copy(child = newChild)
+
       // todo: It's hard to write a general rule to pull out nondeterministic expressions
       // from LogicalPlan, currently we only do it for UnaryNode which has same output
       // schema with its child.
      case p: UnaryNode if p.output == p.child.output && p.expressions.exists(!_.deterministic) =>
-        val nondeterministicExprs = p.expressions.filterNot(_.deterministic).flatMap { expr =>
-          val leafNondeterministic = expr.collect {
-            case n: Nondeterministic => n
-          }
-          leafNondeterministic.map { e =>
-            val ne = e match {
-              case n: NamedExpression => n
-              case _ => Alias(e, "_nondeterministic")(isGenerated = true)
-            }
-            new TreeNodeRef(e) -> ne
-          }
-        }.toMap
+        val nondeterToAttr = getNondeterToAttr(p.expressions)
         val newPlan = p.transformExpressions { case e =>
-          nondeterministicExprs.get(new TreeNodeRef(e)).map(_.toAttribute).getOrElse(e)
+          nondeterToAttr.get(e).map(_.toAttribute).getOrElse(e)
         }
-        val newChild = Project(p.child.output ++ nondeterministicExprs.values, p.child)
+        val newChild = Project(p.child.output ++ nondeterToAttr.values, p.child)
         Project(p.output, newPlan.withNewChildren(newChild :: Nil))
     }
+
+    private def getNondeterToAttr(exprs: Seq[Expression]): Map[Expression, NamedExpression] = {
+      exprs.filterNot(_.deterministic).flatMap { expr =>
+        val leafNondeterministic = expr.collect { case n: Nondeterministic => n }
+        leafNondeterministic.distinct.map { e =>
+          val ne = e match {
+            case n: NamedExpression => n
+            case _ => Alias(e, "_nondeterministic")(isGenerated = true)
+          }
+          e -> ne
+        }
+      }.toMap
+    }
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/616a78a5/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministicSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministicSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministicSuite.scala
new file mode 100644
index 0000000..72e10ea
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministicSuite.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+
+/**
+ * Test suite for moving non-deterministic expressions into Project.
+ */
+class PullOutNondeterministicSuite extends AnalysisTest {
+
+  private lazy val a = 'a.int
+  private lazy val b = 'b.int
+  private lazy val r = LocalRelation(a, b)
+  private lazy val rnd = Rand(10).as('_nondeterministic)
+  private lazy val rndref = rnd.toAttribute
+
+  test("no-op on filter") {
+    checkAnalysis(
+      r.where(Rand(10) > Literal(1.0)),
+      r.where(Rand(10) > Literal(1.0))
+    )
+  }
+
+  test("sort") {
+    checkAnalysis(
+      r.sortBy(SortOrder(Rand(10), Ascending)),
+      r.select(a, b, rnd).sortBy(SortOrder(rndref, Ascending)).select(a, b)
+    )
+  }
+
+  test("aggregate") {
+    checkAnalysis(
+      r.groupBy(Rand(10))(Rand(10).as("rnd")),
+      r.select(a, b, rnd).groupBy(rndref)(rndref.as("rnd"))
+    )
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/616a78a5/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out
index 9c3a145..c64520f 100644
--- a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out
@@ -137,10 +137,14 @@ GROUP BY position 3 is an aggregate function, and aggregate functions are not al
 -- !query 13
 select a, rand(0), sum(b) from data group by a, 2
 -- !query 13 schema
-struct<>
+struct<a:int,rand(0):double,sum(b):bigint>
 -- !query 13 output
-org.apache.spark.sql.AnalysisException
-nondeterministic expression rand(0) should not appear in grouping expression.;
+1	0.4048454303385226	2
+1	0.8446490682263027	1
+2	0.5871875724155838	1
+2	0.8865128837019473	2
+3	0.742083829230211	1
+3	0.9179913208300406	2
 
 
 -- !query 14


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message