spark-commits mailing list archives

From wenc...@apache.org
Subject spark git commit: [SPARK-18622][SQL] Fix the datatype of the Sum aggregate function
Date Wed, 30 Nov 2016 07:25:41 GMT
Repository: spark
Updated Branches:
  refs/heads/master a1d9138ab -> 879ba7111


[SPARK-18622][SQL] Fix the datatype of the Sum aggregate function

## What changes were proposed in this pull request?
The result of a `sum` aggregate function is typically a Decimal, Double or a Long. Currently
the output dataType is based on the input's dataType.

The `FunctionArgumentConversion` rule will make sure that the input is promoted to the largest
type, and that also ensures that the output uses a (hopefully) sufficiently large output dataType.
The issue is that `sum` is already in a resolved state when we cast the input type. This means
that rules which assume that the dataType of the expression no longer changes could have been
applied in the meantime. This is what happens if we apply `WidenSetOperationTypes` before
applying the casts, and this breaks analysis.

The most straightforward and future-proof solution is to make `sum` always output the widest
dataType in its class (Long for IntegralTypes, Decimal for DecimalTypes & Double for FloatType
and DoubleType). This PR implements that solution.
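
As a rough illustration of that mapping, here is a minimal, self-contained sketch that does not depend on Spark. The type hierarchy, the `SumResultType` object and the `MaxPrecision` constant are hypothetical stand-ins for the Catalyst classes, and the precision cap of 38 is an assumption about what `DecimalType.bounded` does.

```scala
// Hypothetical stand-ins for Catalyst's data types; these are NOT the real Spark classes.
sealed trait DataType
sealed trait IntegralType extends DataType
case object ByteType extends IntegralType
case object ShortType extends IntegralType
case object IntegerType extends IntegralType
case object LongType extends IntegralType
case object FloatType extends DataType
case object DoubleType extends DataType
final case class DecimalType(precision: Int, scale: Int) extends DataType

object SumResultType {
  // Assumed upper bound on decimal precision and scale, mimicking DecimalType.bounded.
  val MaxPrecision: Int = 38

  // The output type depends only on the class of the input type,
  // so it does not change when the input is cast later on.
  def sumResultType(input: DataType): DataType = input match {
    case DecimalType(precision, scale) =>
      // Leave headroom for the accumulated value, capped at the assumed maximum.
      DecimalType(math.min(precision + 10, MaxPrecision), math.min(scale, MaxPrecision))
    case _: IntegralType => LongType   // Byte, Short, Integer and Long all widen to Long
    case _ => DoubleType               // Float and Double widen to Double
  }
}
```

With this rule, `SumResultType.sumResultType(IntegerType)` yields `LongType` whether or not a cast to Long has been inserted around the child yet, which is the property the fix relies on.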

We should move expression-specific type casting rules into the given Expression at some point.

## How was this patch tested?
Added (regression) tests to SQLQueryTestSuite's `union.sql`.

Author: Herman van Hovell <hvanhovell@databricks.com>

Closes #16063 from hvanhovell/SPARK-18622.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/879ba711
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/879ba711
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/879ba711

Branch: refs/heads/master
Commit: 879ba71110b6c85a4e47133620fbae7580650a6f
Parents: a1d9138
Author: Herman van Hovell <hvanhovell@databricks.com>
Authored: Wed Nov 30 15:25:33 2016 +0800
Committer: Wenchen Fan <wenchen@databricks.com>
Committed: Wed Nov 30 15:25:33 2016 +0800

----------------------------------------------------------------------
 .../catalyst/expressions/aggregate/Sum.scala    |  6 +-
 .../test/resources/sql-tests/inputs/union.sql   | 27 +++++++
 .../resources/sql-tests/results/union.sql.out   | 80 ++++++++++++++++++++
 3 files changed, 110 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/879ba711/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index 96e8cee..86e40a9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -33,8 +33,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
   // Return data type.
   override def dataType: DataType = resultType
 
-  override def inputTypes: Seq[AbstractDataType] =
-    Seq(TypeCollection(LongType, DoubleType, DecimalType))
+  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
 
   override def checkInputDataTypes(): TypeCheckResult =
     TypeUtils.checkForNumericExpr(child.dataType, "function sum")
@@ -42,7 +41,8 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
   private lazy val resultType = child.dataType match {
     case DecimalType.Fixed(precision, scale) =>
       DecimalType.bounded(precision + 10, scale)
-    case _ => child.dataType
+    case _: IntegralType => LongType
+    case _ => DoubleType
   }
 
   private lazy val sumDataType = resultType

http://git-wip-us.apache.org/repos/asf/spark/blob/879ba711/sql/core/src/test/resources/sql-tests/inputs/union.sql
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/inputs/union.sql b/sql/core/src/test/resources/sql-tests/inputs/union.sql
new file mode 100644
index 0000000..1f4780a
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/union.sql
@@ -0,0 +1,27 @@
+CREATE OR REPLACE TEMPORARY VIEW t1 AS VALUES (1, 'a'), (2, 'b') tbl(c1, c2);
+CREATE OR REPLACE TEMPORARY VIEW t2 AS VALUES (1.0, 1), (2.0, 4) tbl(c1, c2);
+
+-- Simple Union
+SELECT *
+FROM   (SELECT * FROM t1
+        UNION ALL
+        SELECT * FROM t1);
+
+-- Type Coerced Union
+SELECT *
+FROM   (SELECT * FROM t1
+        UNION ALL
+        SELECT * FROM t2
+        UNION ALL
+        SELECT * FROM t2);
+
+-- Regression test for SPARK-18622
+SELECT a
+FROM (SELECT 0 a, 0 b
+      UNION ALL
+      SELECT SUM(1) a, CAST(0 AS BIGINT) b
+      UNION ALL SELECT 0 a, 0 b) T;
+
+-- Clean-up
+DROP VIEW IF EXISTS t1;
+DROP VIEW IF EXISTS t2;

http://git-wip-us.apache.org/repos/asf/spark/blob/879ba711/sql/core/src/test/resources/sql-tests/results/union.sql.out
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/results/union.sql.out b/sql/core/src/test/resources/sql-tests/results/union.sql.out
new file mode 100644
index 0000000..c57028c
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/union.sql.out
@@ -0,0 +1,80 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 7
+
+
+-- !query 0
+CREATE OR REPLACE TEMPORARY VIEW t1 AS VALUES (1, 'a'), (2, 'b') tbl(c1, c2)
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+CREATE OR REPLACE TEMPORARY VIEW t2 AS VALUES (1.0, 1), (2.0, 4) tbl(c1, c2)
+-- !query 1 schema
+struct<>
+-- !query 1 output
+
+
+
+-- !query 2
+SELECT *
+FROM   (SELECT * FROM t1
+        UNION ALL
+        SELECT * FROM t1)
+-- !query 2 schema
+struct<c1:int,c2:string>
+-- !query 2 output
+1	a
+1	a
+2	b
+2	b
+
+
+-- !query 3
+SELECT *
+FROM   (SELECT * FROM t1
+        UNION ALL
+        SELECT * FROM t2
+        UNION ALL
+        SELECT * FROM t2)
+-- !query 3 schema
+struct<c1:decimal(11,1),c2:string>
+-- !query 3 output
+1	1
+1	1
+1	a
+2	4
+2	4
+2	b
+
+
+-- !query 4
+SELECT a
+FROM (SELECT 0 a, 0 b
+      UNION ALL
+      SELECT SUM(1) a, CAST(0 AS BIGINT) b
+      UNION ALL SELECT 0 a, 0 b) T
+-- !query 4 schema
+struct<a:bigint>
+-- !query 4 output
+0
+0
+1
+
+
+-- !query 5
+DROP VIEW IF EXISTS t1
+-- !query 5 schema
+struct<>
+-- !query 5 output
+
+
+
+-- !query 6
+DROP VIEW IF EXISTS t2
+-- !query 6 schema
+struct<>
+-- !query 6 output
+


