spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From r...@apache.org
Subject spark git commit: [SPARK-9172][SQL] Make DecimalPrecision support for Intersect and Except
Date Mon, 20 Jul 2015 03:53:23 GMT
Repository: spark
Updated Branches:
  refs/heads/master 93eb2acfb -> d743bec64


[SPARK-9172][SQL] Make DecimalPrecision support for Intersect and Except

JIRA: https://issues.apache.org/jira/browse/SPARK-9172

Simply make `DecimalPrecision` support `Intersect` and `Except` in addition to `Union`.

Besides, add unit test for `DecimalPrecision` as well.

Author: Liang-Chi Hsieh <viirya@appier.com>

Closes #7511 from viirya/more_decimalprecieion and squashes the following commits:

4d29d10 [Liang-Chi Hsieh] Fix code comment.
9fb0d49 [Liang-Chi Hsieh] Make DecimalPrecision support for Intersect and Except.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d743bec6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d743bec6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d743bec6

Branch: refs/heads/master
Commit: d743bec645fd2a65bd488d2d660b3aa2135b4da6
Parents: 93eb2ac
Author: Liang-Chi Hsieh <viirya@appier.com>
Authored: Sun Jul 19 20:53:18 2015 -0700
Committer: Reynold Xin <rxin@databricks.com>
Committed: Sun Jul 19 20:53:18 2015 -0700

----------------------------------------------------------------------
 .../catalyst/analysis/HiveTypeCoercion.scala    | 86 +++++++++++---------
 .../analysis/HiveTypeCoercionSuite.scala        | 55 +++++++++++++
 2 files changed, 104 insertions(+), 37 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d743bec6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index ff20835..e214545 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -335,7 +335,7 @@ object HiveTypeCoercion {
    * - INT gets turned into DECIMAL(10, 0)
    * - LONG gets turned into DECIMAL(20, 0)
    * - FLOAT and DOUBLE
-   *   1. Union operation:
+   *   1. Union, Intersect and Except operations:
    *      FLOAT gets turned into DECIMAL(7, 7), DOUBLE gets turned into DECIMAL(15, 15) (this
is the
    *      same as Hive)
    *   2. Other operation:
@@ -362,47 +362,59 @@ object HiveTypeCoercion {
       DoubleType -> DecimalType(15, 15)
     )
 
-    def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-      // fix decimal precision for union
-      case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
-        val castedInput = left.output.zip(right.output).map {
-          case (lhs, rhs) if lhs.dataType != rhs.dataType =>
-            (lhs.dataType, rhs.dataType) match {
-              case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) =>
-                // Union decimals with precision/scale p1/s2 and p2/s2  will be promoted
to
-                // DecimalType(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2))
-                val fixedType = DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2), max(s1,
s2))
-                (Alias(Cast(lhs, fixedType), lhs.name)(), Alias(Cast(rhs, fixedType), rhs.name)())
-              case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) =>
-                (Alias(Cast(lhs, intTypeToFixed(t)), lhs.name)(), rhs)
-              case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) =>
-                (lhs, Alias(Cast(rhs, intTypeToFixed(t)), rhs.name)())
-              case (t, DecimalType.Fixed(p, s)) if floatTypeToFixed.contains(t) =>
-                (Alias(Cast(lhs, floatTypeToFixed(t)), lhs.name)(), rhs)
-              case (DecimalType.Fixed(p, s), t) if floatTypeToFixed.contains(t) =>
-                (lhs, Alias(Cast(rhs, floatTypeToFixed(t)), rhs.name)())
-              case _ => (lhs, rhs)
-            }
-          case other => other
-        }
+    private def castDecimalPrecision(
+        left: LogicalPlan,
+        right: LogicalPlan): (LogicalPlan, LogicalPlan) = {
+      val castedInput = left.output.zip(right.output).map {
+        case (lhs, rhs) if lhs.dataType != rhs.dataType =>
+          (lhs.dataType, rhs.dataType) match {
+            case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) =>
+              // Decimals with precision/scale p1/s1 and p2/s2 will be promoted to
+              // DecimalType(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2))
+              val fixedType = DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2), max(s1, s2))
+              (Alias(Cast(lhs, fixedType), lhs.name)(), Alias(Cast(rhs, fixedType), rhs.name)())
+            case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) =>
+              (Alias(Cast(lhs, intTypeToFixed(t)), lhs.name)(), rhs)
+            case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) =>
+              (lhs, Alias(Cast(rhs, intTypeToFixed(t)), rhs.name)())
+            case (t, DecimalType.Fixed(p, s)) if floatTypeToFixed.contains(t) =>
+              (Alias(Cast(lhs, floatTypeToFixed(t)), lhs.name)(), rhs)
+            case (DecimalType.Fixed(p, s), t) if floatTypeToFixed.contains(t) =>
+              (lhs, Alias(Cast(rhs, floatTypeToFixed(t)), rhs.name)())
+            case _ => (lhs, rhs)
+          }
+        case other => other
+      }
 
-        val (castedLeft, castedRight) = castedInput.unzip
+      val (castedLeft, castedRight) = castedInput.unzip
 
-        val newLeft =
-          if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) {
-            Project(castedLeft, left)
-          } else {
-            left
-          }
+      val newLeft =
+        if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) {
+          Project(castedLeft, left)
+        } else {
+          left
+        }
 
-        val newRight =
-          if (castedRight.map(_.dataType) != right.output.map(_.dataType)) {
-            Project(castedRight, right)
-          } else {
-            right
-          }
+      val newRight =
+        if (castedRight.map(_.dataType) != right.output.map(_.dataType)) {
+          Project(castedRight, right)
+        } else {
+          right
+        }
+      (newLeft, newRight)
+    }
 
+    def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+      // fix decimal precision for union, intersect and except
+      case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
+        val (newLeft, newRight) = castDecimalPrecision(left, right)
         Union(newLeft, newRight)
+      case i @ Intersect(left, right) if i.childrenResolved && !i.resolved =>
+        val (newLeft, newRight) = castDecimalPrecision(left, right)
+        Intersect(newLeft, newRight)
+      case e @ Except(left, right) if e.childrenResolved && !e.resolved =>
+        val (newLeft, newRight) = castDecimalPrecision(left, right)
+        Except(newLeft, newRight)
 
       // fix decimal precision for expressions
       case q => q.transformExpressions {

http://git-wip-us.apache.org/repos/asf/spark/blob/d743bec6/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 7ee2333..835220c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -346,6 +346,61 @@ class HiveTypeCoercionSuite extends PlanTest {
     checkOutput(r3.right, expectedTypes)
   }
 
+  test("Transform Decimal precision/scale for union except and intersect") {
+    def checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = {
+      logical.output.zip(expectTypes).foreach { case (attr, dt) =>
+        assert(attr.dataType === dt)
+      }
+    }
+
+    val dp = HiveTypeCoercion.DecimalPrecision
+
+    val left1 = LocalRelation(
+      AttributeReference("l", DecimalType(10, 8))())
+    val right1 = LocalRelation(
+      AttributeReference("r", DecimalType(5, 5))())
+    val expectedType1 = Seq(DecimalType(math.max(8, 5) + math.max(10 - 8, 5 - 5), math.max(8,
5)))
+
+    val r1 = dp(Union(left1, right1)).asInstanceOf[Union]
+    val r2 = dp(Except(left1, right1)).asInstanceOf[Except]
+    val r3 = dp(Intersect(left1, right1)).asInstanceOf[Intersect]
+
+    checkOutput(r1.left, expectedType1)
+    checkOutput(r1.right, expectedType1)
+    checkOutput(r2.left, expectedType1)
+    checkOutput(r2.right, expectedType1)
+    checkOutput(r3.left, expectedType1)
+    checkOutput(r3.right, expectedType1)
+
+    val plan1 = LocalRelation(
+      AttributeReference("l", DecimalType(10, 10))())
+
+    val rightTypes = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType)
+    val expectedTypes = Seq(DecimalType(3, 0), DecimalType(5, 0), DecimalType(10, 0),
+      DecimalType(20, 0), DecimalType(7, 7), DecimalType(15, 15))
+
+    rightTypes.zip(expectedTypes).map { case (rType, expectedType) =>
+      val plan2 = LocalRelation(
+        AttributeReference("r", rType)())
+
+      val r1 = dp(Union(plan1, plan2)).asInstanceOf[Union]
+      val r2 = dp(Except(plan1, plan2)).asInstanceOf[Except]
+      val r3 = dp(Intersect(plan1, plan2)).asInstanceOf[Intersect]
+
+      checkOutput(r1.right, Seq(expectedType))
+      checkOutput(r2.right, Seq(expectedType))
+      checkOutput(r3.right, Seq(expectedType))
+
+      val r4 = dp(Union(plan2, plan1)).asInstanceOf[Union]
+      val r5 = dp(Except(plan2, plan1)).asInstanceOf[Except]
+      val r6 = dp(Intersect(plan2, plan1)).asInstanceOf[Intersect]
+
+      checkOutput(r4.left, Seq(expectedType))
+      checkOutput(r5.left, Seq(expectedType))
+      checkOutput(r6.left, Seq(expectedType))
+    }
+  }
+
   /**
    * There are rules that need to not fire before child expressions get resolved.
    * We use this test to make sure those rules do not fire early.


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message