spark-commits mailing list archives

From hvanhov...@apache.org
Subject spark git commit: [SPARK-14664][SQL] Implement DecimalAggregates optimization for Window queries
Date Wed, 27 Apr 2016 19:36:37 GMT
Repository: spark
Updated Branches:
  refs/heads/master c74fd1e54 -> af92299fd


[SPARK-14664][SQL] Implement DecimalAggregates optimization for Window queries

## What changes were proposed in this pull request?

This PR implements the decimal aggregation optimization for window queries by extending the
existing `DecimalAggregates` rule. `DecimalAggregates` was originally designed to rewrite plain
`sum/avg(decimal)` aggregates, and it breaks the recently added window queries shown below,
which fail with `Unsupported window function`. The same queries work when the current
`DecimalAggregates` rule is not applied.

**Sum**
```scala
scala> sql("select sum(a) over () from (select explode(array(1.0,2.0)) a) t").head
java.lang.RuntimeException: Unsupported window function: MakeDecimal((sum(UnscaledValue(a#31)),mode=Complete,isDistinct=false),12,1)
scala> sql("select sum(a) over () from (select explode(array(1.0,2.0)) a) t").explain()
== Physical Plan ==
WholeStageCodegen
:  +- Project [sum(a) OVER (  ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#23]
:     +- INPUT
+- Window [MakeDecimal((sum(UnscaledValue(a#21)),mode=Complete,isDistinct=false),12,1) windowspecdefinition(ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS sum(a) OVER (  ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#23]
   +- Exchange SinglePartition, None
      +- Generate explode([1.0,2.0]), false, false, [a#21]
         +- Scan OneRowRelation[]
```

**Average**
```scala
scala> sql("select avg(a) over () from (select explode(array(1.0,2.0)) a) t").head
java.lang.RuntimeException: Unsupported window function: cast(((avg(UnscaledValue(a#40)),mode=Complete,isDistinct=false) / 10.0) as decimal(6,5))
scala> sql("select avg(a) over () from (select explode(array(1.0,2.0)) a) t").explain()
== Physical Plan ==
WholeStageCodegen
:  +- Project [avg(a) OVER (  ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#44]
:     +- INPUT
+- Window [cast(((avg(UnscaledValue(a#42)),mode=Complete,isDistinct=false) / 10.0) as decimal(6,5)) windowspecdefinition(ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS avg(a) OVER (  ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#44]
   +- Exchange SinglePartition, None
      +- Generate explode([1.0,2.0]), false, false, [a#42]
         +- Scan OneRowRelation[]
```
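
The constants in these plans follow from the rule in the diff: for a `decimal(2,1)` column, the sum is computed over unscaled `Long` values and rebuilt with precision `2 + 10 = 12`, and the average is divided by `10^scale = 10.0` and cast to `decimal(2 + 4, 1 + 4) = decimal(6,5)`. A minimal plain-Scala sketch of that arithmetic, assuming nothing from Spark (the object and value names are illustrative, and `setScale` only approximates what Spark's cast does):

```scala
// Plain-Scala sketch of the unscaled-value arithmetic; names are illustrative, not Spark APIs.
object UnscaledArithmeticSketch {
  val scale = 1 // the column type in the example is decimal(2,1)
  val values = Seq(BigDecimal("1.0"), BigDecimal("2.0"))

  // UnscaledValue: the digits of the decimal as a Long, e.g. 1.0 -> 10.
  val unscaled: Seq[Long] = values.map(_.bigDecimal.unscaledValue.longValueExact)

  // Sum: add the Longs, then rebuild a decimal with the original scale (the MakeDecimal step).
  val sum: BigDecimal = BigDecimal(BigInt(unscaled.sum), scale)

  // Average: average the Longs as Doubles, divide by 10^scale, then widen the scale by 4.
  val avgDouble: Double =
    unscaled.map(_.toDouble).sum / unscaled.size / math.pow(10.0, scale)
  val avg: BigDecimal =
    BigDecimal(avgDouble).setScale(scale + 4, BigDecimal.RoundingMode.HALF_UP)
}
```

The point of the rewrite is that the aggregation itself runs on `Long` or `Double` values, with a single conversion back to a decimal at the end.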

After this PR, those queries run successfully, and the new optimized physical plans look like the following.

**Sum**
```scala
scala> sql("select sum(a) over () from (select explode(array(1.0,2.0)) a) t").explain()
== Physical Plan ==
WholeStageCodegen
:  +- Project [sum(a) OVER (  ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#35]
:     +- INPUT
+- Window [MakeDecimal((sum(UnscaledValue(a#33)),mode=Complete,isDistinct=false) windowspecdefinition(ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),12,1) AS sum(a) OVER (  ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#35]
   +- Exchange SinglePartition, None
      +- Generate explode([1.0,2.0]), false, false, [a#33]
         +- Scan OneRowRelation[]
```

**Average**
```scala
scala> sql("select avg(a) over () from (select explode(array(1.0,2.0)) a) t").explain()
== Physical Plan ==
WholeStageCodegen
:  +- Project [avg(a) OVER (  ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#47]
:     +- INPUT
+- Window [cast(((avg(UnscaledValue(a#45)),mode=Complete,isDistinct=false) windowspecdefinition(ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) / 10.0) as decimal(6,5)) AS avg(a) OVER (  ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#47]
   +- Exchange SinglePartition, None
      +- Generate explode([1.0,2.0]), false, false, [a#45]
         +- Scan OneRowRelation[]
```
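
The difference between the failing and the fixed plans is where the wrapper sits: before, `MakeDecimal` (or the `cast`) was placed inside the window expression, so the window function was no longer a bare aggregate; after, the wrapper encloses the whole window expression. A toy sketch of that ordering, using hypothetical case classes rather than the real Catalyst ones, and omitting both the recursion over the tree and the precision guard from the actual rule:

```scala
// Toy expression tree; these case classes are hypothetical stand-ins, not Catalyst classes.
sealed trait Expr
case class Col(name: String) extends Expr
case class UnscaledValue(child: Expr) extends Expr
case class Sum(child: Expr) extends Expr
case class Window(function: Expr) extends Expr
case class MakeDecimal(child: Expr, precision: Int, scale: Int) extends Expr

object RewriteSketch {
  // The Window case is tried before the plain Sum case, so the MakeDecimal wrapper
  // lands outside the window and the window function stays a bare aggregate.
  def rewrite(e: Expr, precision: Int, scale: Int): Expr = e match {
    case Window(Sum(c)) =>
      MakeDecimal(Window(Sum(UnscaledValue(c))), precision + 10, scale)
    case Sum(c) =>
      MakeDecimal(Sum(UnscaledValue(c)), precision + 10, scale)
    case other => other
  }
}
```

For example, `RewriteSketch.rewrite(Window(Sum(Col("a"))), 2, 1)` yields a `MakeDecimal` whose child is the `Window` node, mirroring the shape of the fixed **Sum** plan above.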

In this PR, the *SUM over window* pattern matching is based on code by hvanhovell; he should
be credited for the work he did.

## How was this patch tested?

Pass the Jenkins tests (with newly added test cases).

Author: Dongjoon Hyun <dongjoon@apache.org>

Closes #12421 from dongjoon-hyun/SPARK-14664.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/af92299f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/af92299f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/af92299f

Branch: refs/heads/master
Commit: af92299fdb8d358bbfeb2bf0d88ef7da8e2144b9
Parents: c74fd1e
Author: Dongjoon Hyun <dongjoon@apache.org>
Authored: Wed Apr 27 21:36:19 2016 +0200
Committer: Herman van Hovell <hvanhovell@questtec.nl>
Committed: Wed Apr 27 21:36:19 2016 +0200

----------------------------------------------------------------------
 .../sql/catalyst/optimizer/Optimizer.scala      |  40 ++++--
 .../optimizer/DecimalAggregatesSuite.scala      | 122 +++++++++++++++++++
 .../spark/sql/DataFrameAggregateSuite.scala     |  11 +-
 3 files changed, 161 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/af92299f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index b26ceba..54bf4a5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1343,17 +1343,35 @@ object DecimalAggregates extends Rule[LogicalPlan] {
   /** Maximum number of decimal digits representable precisely in a Double */
   private val MAX_DOUBLE_DIGITS = 15
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
-    case ae @ AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), _, _, _)
-      if prec + 10 <= MAX_LONG_DIGITS =>
-      MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale)
-
-    case ae @ AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), _, _, _)
-      if prec + 4 <= MAX_DOUBLE_DIGITS =>
-      val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e)))
-      Cast(
-        Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
-        DecimalType(prec + 4, scale + 4))
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case q: LogicalPlan => q transformExpressionsDown {
+      case we @ WindowExpression(ae @ AggregateExpression(af, _, _, _), _) => af match {
+        case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS =>
+          MakeDecimal(we.copy(windowFunction = ae.copy(aggregateFunction = Sum(UnscaledValue(e)))),
+            prec + 10, scale)
+
+        case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS =>
+          val newAggExpr =
+            we.copy(windowFunction = ae.copy(aggregateFunction = Average(UnscaledValue(e))))
+          Cast(
+            Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
+            DecimalType(prec + 4, scale + 4))
+
+        case _ => we
+      }
+      case ae @ AggregateExpression(af, _, _, _) => af match {
+        case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS =>
+          MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale)
+
+        case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS =>
+          val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e)))
+          Cast(
+            Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
+            DecimalType(prec + 4, scale + 4))
+
+        case _ => ae
+      }
+    }
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/af92299f/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala
new file mode 100644
index 0000000..711294e
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.types.DecimalType
+
+class DecimalAggregatesSuite extends PlanTest {
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches = Batch("Decimal Optimizations", FixedPoint(100),
+      DecimalAggregates) :: Nil
+  }
+
+  val testRelation = LocalRelation('a.decimal(2, 1), 'b.decimal(12, 1))
+
+  test("Decimal Sum Aggregation: Optimized") {
+    val originalQuery = testRelation.select(sum('a))
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation
+      .select(MakeDecimal(sum(UnscaledValue('a)), 12, 1).as("sum(a)")).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Decimal Sum Aggregation: Not Optimized") {
+    val originalQuery = testRelation.select(sum('b))
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = originalQuery.analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Decimal Average Aggregation: Optimized") {
+    val originalQuery = testRelation.select(avg('a))
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation
+      .select((avg(UnscaledValue('a)) / 10.0).cast(DecimalType(6, 5)).as("avg(a)")).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Decimal Average Aggregation: Not Optimized") {
+    val originalQuery = testRelation.select(avg('b))
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = originalQuery.analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Decimal Sum Aggregation over Window: Optimized") {
+    val spec = windowSpec(Seq('a), Nil, UnspecifiedFrame)
+    val originalQuery = testRelation.select(windowExpr(sum('a), spec).as('sum_a))
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation
+      .select('a)
+      .window(
+        Seq(MakeDecimal(windowExpr(sum(UnscaledValue('a)), spec), 12, 1).as('sum_a)),
+        Seq('a),
+        Nil)
+      .select('a, 'sum_a, 'sum_a)
+      .select('sum_a)
+      .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Decimal Sum Aggregation over Window: Not Optimized") {
+    val spec = windowSpec('b :: Nil, Nil, UnspecifiedFrame)
+    val originalQuery = testRelation.select(windowExpr(sum('b), spec))
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = originalQuery.analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Decimal Average Aggregation over Window: Optimized") {
+    val spec = windowSpec(Seq('a), Nil, UnspecifiedFrame)
+    val originalQuery = testRelation.select(windowExpr(avg('a), spec).as('avg_a))
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation
+      .select('a)
+      .window(
+        Seq((windowExpr(avg(UnscaledValue('a)), spec) / 10.0).cast(DecimalType(6, 5)).as('avg_a)),
+        Seq('a),
+        Nil)
+      .select('a, 'avg_a, 'avg_a)
+      .select('avg_a)
+      .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Decimal Average Aggregation over Window: Not Optimized") {
+    val spec = windowSpec('b :: Nil, Nil, UnspecifiedFrame)
+    val originalQuery = testRelation.select(windowExpr(avg('b), spec))
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = originalQuery.analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/af92299f/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 2f685c5..63f4b75 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.test.SQLTestData.DecimalData
-import org.apache.spark.sql.types.DecimalType
+import org.apache.spark.sql.types.{Decimal, DecimalType}
 
 case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Double)
 
@@ -430,4 +430,13 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
         expr("kurtosis(a)")),
       Row(null, null, null, null, null))
   }
+
+  test("SPARK-14664: Decimal sum/avg over window should work.") {
+    checkAnswer(
+      sqlContext.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"),
+      Row(6.0) :: Row(6.0) :: Row(6.0) :: Nil)
+    checkAnswer(
+      sqlContext.sql("select avg(a) over () from values 1.0, 2.0, 3.0 T(a)"),
+      Row(2.0) :: Row(2.0) :: Row(2.0) :: Nil)
+  }
 }



