spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From marmb...@apache.org
Subject git commit: [SPARK-2860][SQL] Fix coercion of CASE WHEN.
Date Tue, 05 Aug 2014 18:18:04 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.1 e3fe6571d -> 388ab534b


[SPARK-2860][SQL] Fix coercion of CASE WHEN.

Author: Michael Armbrust <michael@databricks.com>

Closes #1785 from marmbrus/caseNull and squashes the following commits:

126006d [Michael Armbrust] better error message
2fe357f [Michael Armbrust] Fix coercion of CASE WHEN.

(cherry picked from commit 6e821e3d1ae1ed23459bc7f1098510b968130152)
Signed-off-by: Michael Armbrust <michael@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/388ab534
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/388ab534
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/388ab534

Branch: refs/heads/branch-1.1
Commit: 388ab534b318e6736484a2fab6f88390abbf8c55
Parents: e3fe657
Author: Michael Armbrust <michael@databricks.com>
Authored: Tue Aug 5 11:17:50 2014 -0700
Committer: Michael Armbrust <michael@databricks.com>
Committed: Tue Aug 5 11:18:00 2014 -0700

----------------------------------------------------------------------
 .../catalyst/analysis/HiveTypeCoercion.scala    | 56 +++++++++++---------
 ...null case-0-581cdfe70091e546414b202da2cebdcb |  1 +
 .../sql/hive/execution/HiveQuerySuite.scala     |  3 ++
 3 files changed, 36 insertions(+), 24 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/388ab534/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index e94f2a3..15eb598 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -49,10 +49,21 @@ trait HiveTypeCoercion {
     BooleanCasts ::
     StringToIntegralCasts ::
     FunctionArgumentConversion ::
-    CastNulls ::
+    CaseWhenCoercion ::
     Division ::
     Nil
 
+  trait TypeWidening {
+    def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = {
+      // Try and find a promotion rule that contains both types in question.
+      val applicableConversion =
+        HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2))
+
+      // If found return the widest common type, otherwise None
+      applicableConversion.map(_.filter(t => t == t1 || t == t2).last)
+    }
+  }
+
   /**
    * Applies any changes to [[AttributeReference]] data types that are made by other rules to
    * instances higher in the query tree.
@@ -133,16 +144,7 @@ trait HiveTypeCoercion {
    * - LongType to FloatType
    * - LongType to DoubleType
    */
-  object WidenTypes extends Rule[LogicalPlan] {
-
-    def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = {
-      // Try and find a promotion rule that contains both types in question.
-      val applicableConversion =
-        HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2))
-
-      // If found return the widest common type, otherwise None
-      applicableConversion.map(_.filter(t => t == t1 || t == t2).last)
-    }
+  object WidenTypes extends Rule[LogicalPlan] with TypeWidening {
 
     def apply(plan: LogicalPlan): LogicalPlan = plan transform {
       case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
@@ -336,28 +338,34 @@ trait HiveTypeCoercion {
   }
 
   /**
-   * Ensures that NullType gets casted to some other types under certain circumstances.
+   * Coerces the type of different branches of a CASE WHEN statement to a common type.
    */
-  object CastNulls extends Rule[LogicalPlan] {
+  object CaseWhenCoercion extends Rule[LogicalPlan] with TypeWidening {
     def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
-      case cw @ CaseWhen(branches) =>
+      case cw @ CaseWhen(branches) if !cw.resolved && !branches.exists(!_.resolved) =>
         val valueTypes = branches.sliding(2, 2).map {
-          case Seq(_, value) if value.resolved => Some(value.dataType)
-          case Seq(elseVal) if elseVal.resolved => Some(elseVal.dataType)
-          case _ => None
+          case Seq(_, value) => value.dataType
+          case Seq(elseVal) => elseVal.dataType
         }.toSeq
-        if (valueTypes.distinct.size == 2 && valueTypes.exists(_ == Some(NullType))) {
-          val otherType = valueTypes.filterNot(_ == Some(NullType))(0).get
+
+        logDebug(s"Input values for null casting ${valueTypes.mkString(",")}")
+
+        if (valueTypes.distinct.size > 1) {
+          val commonType = valueTypes.reduce { (v1, v2) =>
+            findTightestCommonType(v1, v2)
+              .getOrElse(sys.error(
+                s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2"))
+          }
           val transformedBranches = branches.sliding(2, 2).map {
-            case Seq(cond, value) if value.resolved && value.dataType == NullType =>
-              Seq(cond, Cast(value, otherType))
-            case Seq(elseVal) if elseVal.resolved && elseVal.dataType == NullType =>
-              Seq(Cast(elseVal, otherType))
+            case Seq(cond, value) if value.dataType != commonType =>
+              Seq(cond, Cast(value, commonType))
+            case Seq(elseVal) if elseVal.dataType != commonType =>
+              Seq(Cast(elseVal, commonType))
             case s => s
           }.reduce(_ ++ _)
           CaseWhen(transformedBranches)
         } else {
-          // It is possible to have more types due to the possibility of short-circuiting.
+          // Types match up.  Hopefully some other rule fixes whatever is wrong with resolution.
           cw
         }
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/388ab534/sql/hive/src/test/resources/golden/null case-0-581cdfe70091e546414b202da2cebdcb
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/resources/golden/null case-0-581cdfe70091e546414b202da2cebdcb b/sql/hive/src/test/resources/golden/null case-0-581cdfe70091e546414b202da2cebdcb
new file mode 100644
index 0000000..d00491f
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/null case-0-581cdfe70091e546414b202da2cebdcb	
@@ -0,0 +1 @@
+1

http://git-wip-us.apache.org/repos/asf/spark/blob/388ab534/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index aa810a2..2f0be49 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -32,6 +32,9 @@ case class TestData(a: Int, b: String)
  */
 class HiveQuerySuite extends HiveComparisonTest {
 
+  createQueryTest("null case",
+    "SELECT case when(true) then 1 else null end FROM src LIMIT 1")
+
   createQueryTest("single case",
     """SELECT case when true then 1 else 2 end FROM src LIMIT 1""")
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message