spark-commits mailing list archives

From wenc...@apache.org
Subject spark git commit: [SPARK-15776][SQL] Divide Expression inside Aggregation function is casted to wrong type
Date Wed, 15 Jun 2016 21:34:25 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-2.0 fe8ac729e -> 2c1aae442


[SPARK-15776][SQL] Divide Expression inside Aggregation function is casted to wrong type

## What changes were proposed in this pull request?

This PR fixes the problem that a Divide expression inside an aggregation function is cast to
the wrong type, which causes `select 1/2` and `select sum(1/2)` to return different results.

**Before the change:**

```
scala> sql("select 1/2 as a").show()
+---+
|  a|
+---+
|0.5|
+---+

scala> sql("select sum(1/2) as a").show()
+---+
|  a|
+---+
|  0|
+---+

scala> sql("select sum(1 / 2) as a").schema
res4: org.apache.spark.sql.types.StructType = StructType(StructField(a,LongType,true))
```

**After the change:**

```
scala> sql("select 1/2 as a").show()
+---+
|  a|
+---+
|0.5|
+---+

scala> sql("select sum(1/2) as a").show()
+---+
|  a|
+---+
|0.5|
+---+

scala> sql("select sum(1/2) as a").schema
res4: org.apache.spark.sql.types.StructType = StructType(StructField(a,DoubleType,true))
```
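
Queries that depended on the old integral result can recover it with an explicit cast; for example (expected output on a build with this fix, shown for illustration, not captured from a shell):

```
scala> sql("select cast(sum(1/2) as bigint) as a").show()
+---+
|  a|
+---+
|  0|
+---+
```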

## How was this patch tested?

Unit test.

This PR is based on https://github.com/apache/spark/pull/13524 by Sephiroth-Lin

Author: Sean Zhong <seanzhong@databricks.com>

Closes #13651 from clockfly/SPARK-15776.

(cherry picked from commit 9bd80ad6bd43462d16ce24cda77cdfaa336c4e02)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2c1aae44
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2c1aae44
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2c1aae44

Branch: refs/heads/branch-2.0
Commit: 2c1aae44218f5b6cb2087e2d5c074438351fb250
Parents: fe8ac72
Author: Sean Zhong <seanzhong@databricks.com>
Authored: Wed Jun 15 14:34:15 2016 -0700
Committer: Wenchen Fan <wenchen@databricks.com>
Committed: Wed Jun 15 14:34:21 2016 -0700

----------------------------------------------------------------------
 .../sql/catalyst/analysis/TypeCoercion.scala    |  8 +++--
 .../sql/catalyst/expressions/arithmetic.scala   |  3 +-
 .../sql/catalyst/analysis/AnalysisSuite.scala   | 32 +++++++++++++++++
 .../analysis/ExpressionTypeCheckingSuite.scala  |  2 +-
 .../catalyst/analysis/TypeCoercionSuite.scala   | 37 ++++++++++++++++++--
 .../expressions/ArithmeticExpressionSuite.scala | 19 +++++-----
 .../plans/ConstraintPropagationSuite.scala      |  4 +--
 7 files changed, 86 insertions(+), 19 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2c1aae44/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index a5b5b91..16df628 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -525,14 +525,16 @@ object TypeCoercion {
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       // Skip nodes who has not been resolved yet,
       // as this is an extra rule which should be applied at last.
-      case e if !e.resolved => e
+      case e if !e.childrenResolved => e
 
       // Decimal and Double remain the same
       case d: Divide if d.dataType == DoubleType => d
       case d: Divide if d.dataType.isInstanceOf[DecimalType] => d
-
-      case Divide(left, right) => Divide(Cast(left, DoubleType), Cast(right, DoubleType))
+      case Divide(left, right) if isNumeric(left) && isNumeric(right) =>
+        Divide(Cast(left, DoubleType), Cast(right, DoubleType))
     }
+
+    private def isNumeric(ex: Expression): Boolean = ex.dataType.isInstanceOf[NumericType]
   }
 
   /**
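
The two changes in this patch depend on each other: with `inputType` narrowed in `arithmetic.scala` below, an integral `Divide` no longer passes type checking and is never `resolved`, so the guard here must test `childrenResolved` or the rule would skip exactly the nodes it is meant to rewrite. A minimal sketch of running the rule standalone (assuming the catalyst test classpath; it mirrors the test helpers added further down):

```scala
import org.apache.spark.sql.catalyst.analysis.TypeCoercion.Division
import org.apache.spark.sql.catalyst.expressions.{Alias, Divide, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}

// Divide(1, 2) is unresolved under the new inputType, but its children are
// resolved literals, so the childrenResolved guard lets the rule fire.
val before = Project(Seq(Alias(Divide(Literal(1), Literal(2)), "a")()), OneRowRelation)
val after = Division(before)
// `after` now contains Divide(Cast(Literal(1), DoubleType), Cast(Literal(2), DoubleType)).
```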

http://git-wip-us.apache.org/repos/asf/spark/blob/2c1aae44/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index b2df79a..4db1352 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -213,7 +213,7 @@ case class Multiply(left: Expression, right: Expression)
 case class Divide(left: Expression, right: Expression)
     extends BinaryArithmetic with NullIntolerant {
 
-  override def inputType: AbstractDataType = NumericType
+  override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType)
 
   override def symbol: String = "/"
   override def decimalMethod: String = "$div"
@@ -221,7 +221,6 @@ case class Divide(left: Expression, right: Expression)
 
   private lazy val div: (Any, Any) => Any = dataType match {
     case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div
-    case it: IntegralType => it.integral.asInstanceOf[Integral[Any]].quot
   }
 
   override def eval(input: InternalRow): Any = {
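
The `div` closure now needs only the `FractionalType` branch, and the divide-by-zero check in `eval` is unchanged, so division by zero still yields `null` rather than `Infinity`, even for doubles (the updated ArithmeticExpressionSuite below asserts this). For illustration (expected output, not captured from a shell):

```
scala> sql("select 1/0 as a").show()
+----+
|   a|
+----+
|null|
+----+
```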

http://git-wip-us.apache.org/repos/asf/spark/blob/2c1aae44/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 77ea29e..102c78b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -345,4 +345,36 @@ class AnalysisSuite extends AnalysisTest {
 
     assertAnalysisSuccess(query)
   }
+
+  private def assertExpressionType(
+      expression: Expression,
+      expectedDataType: DataType): Unit = {
+    val afterAnalyze =
+      Project(Seq(Alias(expression, "a")()), OneRowRelation).analyze.expressions.head
+    if (!afterAnalyze.dataType.equals(expectedDataType)) {
+      fail(
+        s"""
+           |data type of expression $expression doesn't match expected:
+           |Actual data type:
+           |${afterAnalyze.dataType}
+           |
+           |Expected data type:
+           |${expectedDataType}
+         """.stripMargin)
+    }
+  }
+
+  test("SPARK-15776: test whether Divide expression's data type can be deduced correctly
by " +
+    "analyzer") {
+    assertExpressionType(sum(Divide(1, 2)), DoubleType)
+    assertExpressionType(sum(Divide(1.0, 2)), DoubleType)
+    assertExpressionType(sum(Divide(1, 2.0)), DoubleType)
+    assertExpressionType(sum(Divide(1.0, 2.0)), DoubleType)
+    assertExpressionType(sum(Divide(1, 2.0f)), DoubleType)
+    assertExpressionType(sum(Divide(1.0f, 2)), DoubleType)
+    assertExpressionType(sum(Divide(1, Decimal(2))), DecimalType(31, 11))
+    assertExpressionType(sum(Divide(Decimal(1), 2)), DecimalType(31, 11))
+    assertExpressionType(sum(Divide(Decimal(1), 2.0)), DoubleType)
+    assertExpressionType(sum(Divide(1.0, Decimal(2.0))), DoubleType)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/2c1aae44/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 660dc86..54436ea 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -85,7 +85,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     assertError(Subtract('booleanField, 'booleanField),
       "requires (numeric or calendarinterval) type")
     assertError(Multiply('booleanField, 'booleanField), "requires numeric type")
-    assertError(Divide('booleanField, 'booleanField), "requires numeric type")
+    assertError(Divide('booleanField, 'booleanField), "requires (double or decimal) type")
     assertError(Remainder('booleanField, 'booleanField), "requires numeric type")
 
     assertError(BitwiseAnd('booleanField, 'booleanField), "requires integral type")
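
With the narrowed `inputType`, operands that cannot be implicitly cast to double or decimal now fail analysis with the message asserted here; roughly (abridged, exact wording may differ):

```
scala> sql("select true / false").show()
org.apache.spark.sql.AnalysisException: cannot resolve '(true / false)' due to data type
mismatch: '(true / false)' requires (double or decimal) type, not boolean; ...
```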

http://git-wip-us.apache.org/repos/asf/spark/blob/2c1aae44/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index 7435399..971c99b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -19,10 +19,12 @@ package org.apache.spark.sql.catalyst.analysis
 
 import java.sql.Timestamp
 
+import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{Division, FunctionArgumentConversion}
+import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.CalendarInterval
 
@@ -199,9 +201,20 @@ class TypeCoercionSuite extends PlanTest {
   }
 
   private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) {
+    ruleTest(Seq(rule), initial, transformed)
+  }
+
+  private def ruleTest(
+      rules: Seq[Rule[LogicalPlan]],
+      initial: Expression,
+      transformed: Expression): Unit = {
     val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
+    val analyzer = new RuleExecutor[LogicalPlan] {
+      override val batches = Seq(Batch("Resolution", FixedPoint(3), rules: _*))
+    }
+
     comparePlans(
-      rule(Project(Seq(Alias(initial, "a")()), testRelation)),
+      analyzer.execute(Project(Seq(Alias(initial, "a")()), testRelation)),
       Project(Seq(Alias(transformed, "a")()), testRelation))
   }
 
@@ -630,6 +643,26 @@ class TypeCoercionSuite extends PlanTest {
         Seq(Cast(Literal(1), StringType), Cast(Literal("b"), StringType)))
     )
   }
+
+  test("SPARK-15776 Divide expression's dataType should be casted to Double or Decimal "
+
+    "in aggregation function like sum") {
+    val rules = Seq(FunctionArgumentConversion, Division)
+    // Casts Integer to Double
+    ruleTest(rules, sum(Divide(4, 3)), sum(Divide(Cast(4, DoubleType), Cast(3, DoubleType))))
+    // Left expression is Double, right expression is Int. Another rule ImplicitTypeCasts will
+    // cast the right expression to Double.
+    ruleTest(rules, sum(Divide(4.0, 3)), sum(Divide(4.0, 3)))
+    // Left expression is Int, right expression is Double
+    ruleTest(rules, sum(Divide(4, 3.0)), sum(Divide(Cast(4, DoubleType), Cast(3.0, DoubleType))))
+    // Casts Float to Double
+    ruleTest(
+      rules,
+      sum(Divide(4.0f, 3)),
+      sum(Divide(Cast(4.0f, DoubleType), Cast(3, DoubleType))))
+    // Left expression is Decimal, right expression is Int. Another rule DecimalPrecision will cast
+    // the right expression to Decimal.
+    ruleTest(rules, sum(Divide(Decimal(4.0), 3)), sum(Divide(Decimal(4.0), 3)))
+  }
 }
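
The two "no-op" cases above rely on `Divide` reporting its left child's data type: `Divide(4.0, 3)` is already `DoubleType`, so `Division`'s first case keeps it, and the deliberately excluded `ImplicitTypeCasts` rule is what later casts the right operand. A quick standalone check of that branch (assuming the catalyst test classpath):

```scala
import org.apache.spark.sql.catalyst.analysis.TypeCoercion.Division
import org.apache.spark.sql.catalyst.expressions.{Alias, Divide, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}

// dataType of a BinaryArithmetic is the left child's type, so this Divide is
// already DoubleType and the Division rule leaves the plan untouched.
val plan = Project(Seq(Alias(Divide(Literal(4.0), Literal(3)), "a")()), OneRowRelation)
assert(Division(plan) == plan)
```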
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/2c1aae44/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index 72285c6..2e37887 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -117,8 +117,13 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
     }
   }
 
+  private def testDecimalAndDoubleType(testFunc: (Int => Any) => Unit): Unit = {
+    testFunc(_.toDouble)
+    testFunc(Decimal(_))
+  }
+
   test("/ (Divide) basic") {
-    testNumericDataTypes { convert =>
+    testDecimalAndDoubleType { convert =>
       val left = Literal(convert(2))
       val right = Literal(convert(1))
       val dataType = left.dataType
@@ -128,12 +133,14 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
       checkEvaluation(Divide(left, Literal(convert(0))), null)  // divide by zero
     }
 
-    DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe =>
+    Seq(DoubleType, DecimalType.SYSTEM_DEFAULT).foreach { tpe =>
       checkConsistencyBetweenInterpretedAndCodegen(Divide, tpe, tpe)
     }
   }
 
-  test("/ (Divide) for integral type") {
+  // By fixing SPARK-15776, Divide's inputType is required to be DoubleType or DecimalType.
+  // TODO: in a future release, we should add an IntegerDivide to support integral types.
+  ignore("/ (Divide) for integral type") {
     checkEvaluation(Divide(Literal(1.toByte), Literal(2.toByte)), 0.toByte)
     checkEvaluation(Divide(Literal(1.toShort), Literal(2.toShort)), 0.toShort)
     checkEvaluation(Divide(Literal(1), Literal(2)), 0)
@@ -143,12 +150,6 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(Divide(positiveLongLit, negativeLongLit), 0L)
   }
 
-  test("/ (Divide) for floating point") {
-    checkEvaluation(Divide(Literal(1.0f), Literal(2.0f)), 0.5f)
-    checkEvaluation(Divide(Literal(1.0), Literal(2.0)), 0.5)
-    checkEvaluation(Divide(Literal(Decimal(1.0)), Literal(Decimal(2.0))), Decimal(0.5))
-  }
-
   test("% (Remainder)") {
     testNumericDataTypes { convert =>
       val left = Literal(convert(1))

http://git-wip-us.apache.org/repos/asf/spark/blob/2c1aae44/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
index 81cc6b1..0b73b5e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
@@ -298,7 +298,7 @@ class ConstraintPropagationSuite extends SparkFunSuite {
         Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") + Cast(100, LongType) ===
           Cast(resolveColumn(tr, "c"), LongType),
         Cast(resolveColumn(tr, "d"), DoubleType) /
-          Cast(Cast(10, LongType), DoubleType) ===
+          Cast(10, DoubleType) ===
             Cast(resolveColumn(tr, "e"), DoubleType),
         IsNotNull(resolveColumn(tr, "a")),
         IsNotNull(resolveColumn(tr, "b")),
@@ -312,7 +312,7 @@ class ConstraintPropagationSuite extends SparkFunSuite {
         Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") - Cast(10, LongType) >=
           Cast(resolveColumn(tr, "c"), LongType),
         Cast(resolveColumn(tr, "d"), DoubleType) /
-          Cast(Cast(10, LongType), DoubleType) <
+          Cast(10, DoubleType) <
             Cast(resolveColumn(tr, "e"), DoubleType),
         IsNotNull(resolveColumn(tr, "a")),
         IsNotNull(resolveColumn(tr, "b")),

