spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From r...@apache.org
Subject [1/2] spark git commit: [SPARK-15310][SQL] Rename HiveTypeCoercion -> TypeCoercion
Date Fri, 13 May 2016 07:15:47 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-2.0 86b8f8a9a -> 43570c576


http://git-wip-us.apache.org/repos/asf/spark/blob/43570c57/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
deleted file mode 100644
index b591861..0000000
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ /dev/null
@@ -1,663 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.analysis
-
-import java.sql.Timestamp
-
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.CalendarInterval
-
-class HiveTypeCoercionSuite extends PlanTest {
-
-  test("eligible implicit type cast") {
-    def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = {
-      val got = HiveTypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, from), to)
-      assert(got.map(_.dataType) == Option(expected),
-        s"Failed to cast $from to $to")
-    }
-
-    shouldCast(NullType, NullType, NullType)
-    shouldCast(NullType, IntegerType, IntegerType)
-    shouldCast(NullType, DecimalType, DecimalType.SYSTEM_DEFAULT)
-
-    shouldCast(ByteType, IntegerType, IntegerType)
-    shouldCast(IntegerType, IntegerType, IntegerType)
-    shouldCast(IntegerType, LongType, LongType)
-    shouldCast(IntegerType, DecimalType, DecimalType(10, 0))
-    shouldCast(LongType, IntegerType, IntegerType)
-    shouldCast(LongType, DecimalType, DecimalType(20, 0))
-
-    shouldCast(DateType, TimestampType, TimestampType)
-    shouldCast(TimestampType, DateType, DateType)
-
-    shouldCast(StringType, IntegerType, IntegerType)
-    shouldCast(StringType, DateType, DateType)
-    shouldCast(StringType, TimestampType, TimestampType)
-    shouldCast(IntegerType, StringType, StringType)
-    shouldCast(DateType, StringType, StringType)
-    shouldCast(TimestampType, StringType, StringType)
-
-    shouldCast(StringType, BinaryType, BinaryType)
-    shouldCast(BinaryType, StringType, StringType)
-
-    shouldCast(NullType, TypeCollection(StringType, BinaryType), StringType)
-
-    shouldCast(StringType, TypeCollection(StringType, BinaryType), StringType)
-    shouldCast(BinaryType, TypeCollection(StringType, BinaryType), BinaryType)
-    shouldCast(StringType, TypeCollection(BinaryType, StringType), StringType)
-
-    shouldCast(IntegerType, TypeCollection(IntegerType, BinaryType), IntegerType)
-    shouldCast(IntegerType, TypeCollection(BinaryType, IntegerType), IntegerType)
-    shouldCast(BinaryType, TypeCollection(BinaryType, IntegerType), BinaryType)
-    shouldCast(BinaryType, TypeCollection(IntegerType, BinaryType), BinaryType)
-
-    shouldCast(IntegerType, TypeCollection(StringType, BinaryType), StringType)
-    shouldCast(IntegerType, TypeCollection(BinaryType, StringType), StringType)
-
-    shouldCast(DecimalType.SYSTEM_DEFAULT,
-      TypeCollection(IntegerType, DecimalType), DecimalType.SYSTEM_DEFAULT)
-    shouldCast(DecimalType(10, 2), TypeCollection(IntegerType, DecimalType), DecimalType(10, 2))
-    shouldCast(DecimalType(10, 2), TypeCollection(DecimalType, IntegerType), DecimalType(10, 2))
-    shouldCast(IntegerType, TypeCollection(DecimalType(10, 2), StringType), DecimalType(10, 2))
-
-    shouldCast(StringType, NumericType, DoubleType)
-    shouldCast(StringType, TypeCollection(NumericType, BinaryType), DoubleType)
-
-    // NumericType should not be changed when function accepts any of them.
-    Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType,
-      DecimalType.SYSTEM_DEFAULT, DecimalType(10, 2)).foreach { tpe =>
-      shouldCast(tpe, NumericType, tpe)
-    }
-
-    shouldCast(
-      ArrayType(StringType, false),
-      TypeCollection(ArrayType(StringType), StringType),
-      ArrayType(StringType, false))
-
-    shouldCast(
-      ArrayType(StringType, true),
-      TypeCollection(ArrayType(StringType), StringType),
-      ArrayType(StringType, true))
-  }
-
-  test("ineligible implicit type cast") {
-    def shouldNotCast(from: DataType, to: AbstractDataType): Unit = {
-      val got = HiveTypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, from), to)
-      assert(got.isEmpty, s"Should not be able to cast $from to $to, but got $got")
-    }
-
-    shouldNotCast(IntegerType, DateType)
-    shouldNotCast(IntegerType, TimestampType)
-    shouldNotCast(LongType, DateType)
-    shouldNotCast(LongType, TimestampType)
-    shouldNotCast(DecimalType.SYSTEM_DEFAULT, DateType)
-    shouldNotCast(DecimalType.SYSTEM_DEFAULT, TimestampType)
-
-    shouldNotCast(IntegerType, TypeCollection(DateType, TimestampType))
-
-    shouldNotCast(IntegerType, ArrayType)
-    shouldNotCast(IntegerType, MapType)
-    shouldNotCast(IntegerType, StructType)
-
-    shouldNotCast(CalendarIntervalType, StringType)
-
-    // Don't implicitly cast complex types to string.
-    shouldNotCast(ArrayType(StringType), StringType)
-    shouldNotCast(MapType(StringType, StringType), StringType)
-    shouldNotCast(new StructType().add("a1", StringType), StringType)
-    shouldNotCast(MapType(StringType, StringType), StringType)
-  }
-
-  test("tightest common bound for types") {
-    def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) {
-      var found = HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2)
-      assert(found == tightestCommon,
-        s"Expected $tightestCommon as tightest common type for $t1 and $t2, found $found")
-      // Test both directions to make sure the widening is symmetric.
-      found = HiveTypeCoercion.findTightestCommonTypeOfTwo(t2, t1)
-      assert(found == tightestCommon,
-        s"Expected $tightestCommon as tightest common type for $t2 and $t1, found $found")
-    }
-
-    // Null
-    widenTest(NullType, NullType, Some(NullType))
-
-    // Boolean
-    widenTest(NullType, BooleanType, Some(BooleanType))
-    widenTest(BooleanType, BooleanType, Some(BooleanType))
-    widenTest(IntegerType, BooleanType, None)
-    widenTest(LongType, BooleanType, None)
-
-    // Integral
-    widenTest(NullType, ByteType, Some(ByteType))
-    widenTest(NullType, IntegerType, Some(IntegerType))
-    widenTest(NullType, LongType, Some(LongType))
-    widenTest(ShortType, IntegerType, Some(IntegerType))
-    widenTest(ShortType, LongType, Some(LongType))
-    widenTest(IntegerType, LongType, Some(LongType))
-    widenTest(LongType, LongType, Some(LongType))
-
-    // Floating point
-    widenTest(NullType, FloatType, Some(FloatType))
-    widenTest(NullType, DoubleType, Some(DoubleType))
-    widenTest(FloatType, DoubleType, Some(DoubleType))
-    widenTest(FloatType, FloatType, Some(FloatType))
-    widenTest(DoubleType, DoubleType, Some(DoubleType))
-
-    // Integral mixed with floating point.
-    widenTest(IntegerType, FloatType, Some(FloatType))
-    widenTest(IntegerType, DoubleType, Some(DoubleType))
-    widenTest(IntegerType, DoubleType, Some(DoubleType))
-    widenTest(LongType, FloatType, Some(FloatType))
-    widenTest(LongType, DoubleType, Some(DoubleType))
-
-    // No up-casting for fixed-precision decimal (this is handled by arithmetic rules)
-    widenTest(DecimalType(2, 1), DecimalType(3, 2), None)
-    widenTest(DecimalType(2, 1), DoubleType, None)
-    widenTest(DecimalType(2, 1), IntegerType, None)
-    widenTest(DoubleType, DecimalType(2, 1), None)
-    widenTest(IntegerType, DecimalType(2, 1), None)
-
-    // StringType
-    widenTest(NullType, StringType, Some(StringType))
-    widenTest(StringType, StringType, Some(StringType))
-    widenTest(IntegerType, StringType, None)
-    widenTest(LongType, StringType, None)
-
-    // TimestampType
-    widenTest(NullType, TimestampType, Some(TimestampType))
-    widenTest(TimestampType, TimestampType, Some(TimestampType))
-    widenTest(IntegerType, TimestampType, None)
-    widenTest(StringType, TimestampType, None)
-
-    // ComplexType
-    widenTest(NullType,
-      MapType(IntegerType, StringType, false),
-      Some(MapType(IntegerType, StringType, false)))
-    widenTest(NullType, StructType(Seq()), Some(StructType(Seq())))
-    widenTest(StringType, MapType(IntegerType, StringType, true), None)
-    widenTest(ArrayType(IntegerType), StructType(Seq()), None)
-  }
-
-  private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) {
-    val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
-    comparePlans(
-      rule(Project(Seq(Alias(initial, "a")()), testRelation)),
-      Project(Seq(Alias(transformed, "a")()), testRelation))
-  }
-
-  test("cast NullType for expressions that implement ExpectsInputTypes") {
-    import HiveTypeCoercionSuite._
-
-    ruleTest(HiveTypeCoercion.ImplicitTypeCasts,
-      AnyTypeUnaryExpression(Literal.create(null, NullType)),
-      AnyTypeUnaryExpression(Literal.create(null, NullType)))
-
-    ruleTest(HiveTypeCoercion.ImplicitTypeCasts,
-      NumericTypeUnaryExpression(Literal.create(null, NullType)),
-      NumericTypeUnaryExpression(Literal.create(null, DoubleType)))
-  }
-
-  test("cast NullType for binary operators") {
-    import HiveTypeCoercionSuite._
-
-    ruleTest(HiveTypeCoercion.ImplicitTypeCasts,
-      AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)),
-      AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)))
-
-    ruleTest(HiveTypeCoercion.ImplicitTypeCasts,
-      NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)),
-      NumericTypeBinaryOperator(Literal.create(null, DoubleType), Literal.create(null, DoubleType)))
-  }
-
-  test("coalesce casts") {
-    ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
-      Coalesce(Literal(1.0)
-        :: Literal(1)
-        :: Literal.create(1.0, FloatType)
-        :: Nil),
-      Coalesce(Cast(Literal(1.0), DoubleType)
-        :: Cast(Literal(1), DoubleType)
-        :: Cast(Literal.create(1.0, FloatType), DoubleType)
-        :: Nil))
-    ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
-      Coalesce(Literal(1L)
-        :: Literal(1)
-        :: Literal(new java.math.BigDecimal("1000000000000000000000"))
-        :: Nil),
-      Coalesce(Cast(Literal(1L), DecimalType(22, 0))
-        :: Cast(Literal(1), DecimalType(22, 0))
-        :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0))
-        :: Nil))
-  }
-
-  test("CreateArray casts") {
-    ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
-      CreateArray(Literal(1.0)
-        :: Literal(1)
-        :: Literal.create(1.0, FloatType)
-        :: Nil),
-      CreateArray(Cast(Literal(1.0), DoubleType)
-        :: Cast(Literal(1), DoubleType)
-        :: Cast(Literal.create(1.0, FloatType), DoubleType)
-        :: Nil))
-
-    ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
-      CreateArray(Literal(1.0)
-        :: Literal(1)
-        :: Literal("a")
-        :: Nil),
-      CreateArray(Cast(Literal(1.0), StringType)
-        :: Cast(Literal(1), StringType)
-        :: Cast(Literal("a"), StringType)
-        :: Nil))
-  }
-
-  test("CreateMap casts") {
-    // type coercion for map keys
-    ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
-      CreateMap(Literal(1)
-        :: Literal("a")
-        :: Literal.create(2.0, FloatType)
-        :: Literal("b")
-        :: Nil),
-      CreateMap(Cast(Literal(1), FloatType)
-        :: Literal("a")
-        :: Cast(Literal.create(2.0, FloatType), FloatType)
-        :: Literal("b")
-        :: Nil))
-    // type coercion for map values
-    ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
-      CreateMap(Literal(1)
-        :: Literal("a")
-        :: Literal(2)
-        :: Literal(3.0)
-        :: Nil),
-      CreateMap(Literal(1)
-        :: Cast(Literal("a"), StringType)
-        :: Literal(2)
-        :: Cast(Literal(3.0), StringType)
-        :: Nil))
-    // type coercion for both map keys and values
-    ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
-      CreateMap(Literal(1)
-        :: Literal("a")
-        :: Literal(2.0)
-        :: Literal(3.0)
-        :: Nil),
-      CreateMap(Cast(Literal(1), DoubleType)
-        :: Cast(Literal("a"), StringType)
-        :: Cast(Literal(2.0), DoubleType)
-        :: Cast(Literal(3.0), StringType)
-        :: Nil))
-  }
-
-  test("greatest/least cast") {
-    for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) {
-      ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
-        operator(Literal(1.0)
-          :: Literal(1)
-          :: Literal.create(1.0, FloatType)
-          :: Nil),
-        operator(Cast(Literal(1.0), DoubleType)
-          :: Cast(Literal(1), DoubleType)
-          :: Cast(Literal.create(1.0, FloatType), DoubleType)
-          :: Nil))
-      ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
-        operator(Literal(1L)
-          :: Literal(1)
-          :: Literal(new java.math.BigDecimal("1000000000000000000000"))
-          :: Nil),
-        operator(Cast(Literal(1L), DecimalType(22, 0))
-          :: Cast(Literal(1), DecimalType(22, 0))
-          :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0))
-          :: Nil))
-    }
-  }
-
-  test("nanvl casts") {
-    ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
-      NaNvl(Literal.create(1.0, FloatType), Literal.create(1.0, DoubleType)),
-      NaNvl(Cast(Literal.create(1.0, FloatType), DoubleType), Literal.create(1.0, DoubleType)))
-    ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
-      NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, FloatType)),
-      NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0, FloatType), DoubleType)))
-    ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
-      NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)),
-      NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)))
-  }
-
-  test("type coercion for If") {
-    val rule = HiveTypeCoercion.IfCoercion
-
-    ruleTest(rule,
-      If(Literal(true), Literal(1), Literal(1L)),
-      If(Literal(true), Cast(Literal(1), LongType), Literal(1L)))
-
-    ruleTest(rule,
-      If(Literal.create(null, NullType), Literal(1), Literal(1)),
-      If(Literal.create(null, BooleanType), Literal(1), Literal(1)))
-
-    ruleTest(rule,
-      If(AssertTrue(Literal.create(true, BooleanType)), Literal(1), Literal(2)),
-      If(Cast(AssertTrue(Literal.create(true, BooleanType)), BooleanType), Literal(1), Literal(2)))
-
-    ruleTest(rule,
-      If(AssertTrue(Literal.create(false, BooleanType)), Literal(1), Literal(2)),
-      If(Cast(AssertTrue(Literal.create(false, BooleanType)), BooleanType), Literal(1), Literal(2)))
-  }
-
-  test("type coercion for CaseKeyWhen") {
-    ruleTest(HiveTypeCoercion.ImplicitTypeCasts,
-      CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))),
-      CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a")))
-    )
-    ruleTest(HiveTypeCoercion.CaseWhenCoercion,
-      CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))),
-      CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a")))
-    )
-    ruleTest(HiveTypeCoercion.CaseWhenCoercion,
-      CaseWhen(Seq((Literal(true), Literal(1.2))), Literal.create(1, DecimalType(7, 2))),
-      CaseWhen(Seq((Literal(true), Literal(1.2))),
-        Cast(Literal.create(1, DecimalType(7, 2)), DoubleType))
-    )
-    ruleTest(HiveTypeCoercion.CaseWhenCoercion,
-      CaseWhen(Seq((Literal(true), Literal(100L))), Literal.create(1, DecimalType(7, 2))),
-      CaseWhen(Seq((Literal(true), Cast(Literal(100L), DecimalType(22, 2)))),
-        Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2)))
-    )
-  }
-
-  test("BooleanEquality type cast") {
-    val be = HiveTypeCoercion.BooleanEquality
-    // Use something more than a literal to avoid triggering the simplification rules.
-    val one = Add(Literal(Decimal(1)), Literal(Decimal(0)))
-
-    ruleTest(be,
-      EqualTo(Literal(true), one),
-      EqualTo(Cast(Literal(true), one.dataType), one)
-    )
-
-    ruleTest(be,
-      EqualTo(one, Literal(true)),
-      EqualTo(one, Cast(Literal(true), one.dataType))
-    )
-
-    ruleTest(be,
-      EqualNullSafe(Literal(true), one),
-      EqualNullSafe(Cast(Literal(true), one.dataType), one)
-    )
-
-    ruleTest(be,
-      EqualNullSafe(one, Literal(true)),
-      EqualNullSafe(one, Cast(Literal(true), one.dataType))
-    )
-  }
-
-  test("BooleanEquality simplification") {
-    val be = HiveTypeCoercion.BooleanEquality
-
-    ruleTest(be,
-      EqualTo(Literal(true), Literal(1)),
-      Literal(true)
-    )
-    ruleTest(be,
-      EqualTo(Literal(true), Literal(0)),
-      Not(Literal(true))
-    )
-    ruleTest(be,
-      EqualNullSafe(Literal(true), Literal(1)),
-      And(IsNotNull(Literal(true)), Literal(true))
-    )
-    ruleTest(be,
-      EqualNullSafe(Literal(true), Literal(0)),
-      And(IsNotNull(Literal(true)), Not(Literal(true)))
-    )
-
-    ruleTest(be,
-      EqualTo(Literal(true), Literal(1L)),
-      Literal(true)
-    )
-    ruleTest(be,
-      EqualTo(Literal(new java.math.BigDecimal(1)), Literal(true)),
-      Literal(true)
-    )
-    ruleTest(be,
-      EqualTo(Literal(BigDecimal(0)), Literal(true)),
-      Not(Literal(true))
-    )
-    ruleTest(be,
-      EqualTo(Literal(Decimal(1)), Literal(true)),
-      Literal(true)
-    )
-    ruleTest(be,
-      EqualTo(Literal.create(Decimal(1), DecimalType(8, 0)), Literal(true)),
-      Literal(true)
-    )
-  }
-
-  private def checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = {
-    logical.output.zip(expectTypes).foreach { case (attr, dt) =>
-      assert(attr.dataType === dt)
-    }
-  }
-
-  test("WidenSetOperationTypes for except and intersect") {
-    val firstTable = LocalRelation(
-      AttributeReference("i", IntegerType)(),
-      AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(),
-      AttributeReference("b", ByteType)(),
-      AttributeReference("d", DoubleType)())
-    val secondTable = LocalRelation(
-      AttributeReference("s", StringType)(),
-      AttributeReference("d", DecimalType(2, 1))(),
-      AttributeReference("f", FloatType)(),
-      AttributeReference("l", LongType)())
-
-    val wt = HiveTypeCoercion.WidenSetOperationTypes
-    val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType)
-
-    val r1 = wt(Except(firstTable, secondTable)).asInstanceOf[Except]
-    val r2 = wt(Intersect(firstTable, secondTable)).asInstanceOf[Intersect]
-    checkOutput(r1.left, expectedTypes)
-    checkOutput(r1.right, expectedTypes)
-    checkOutput(r2.left, expectedTypes)
-    checkOutput(r2.right, expectedTypes)
-
-    // Check if a Project is added
-    assert(r1.left.isInstanceOf[Project])
-    assert(r1.right.isInstanceOf[Project])
-    assert(r2.left.isInstanceOf[Project])
-    assert(r2.right.isInstanceOf[Project])
-  }
-
-  test("WidenSetOperationTypes for union") {
-    val firstTable = LocalRelation(
-      AttributeReference("i", IntegerType)(),
-      AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(),
-      AttributeReference("b", ByteType)(),
-      AttributeReference("d", DoubleType)())
-    val secondTable = LocalRelation(
-      AttributeReference("s", StringType)(),
-      AttributeReference("d", DecimalType(2, 1))(),
-      AttributeReference("f", FloatType)(),
-      AttributeReference("l", LongType)())
-    val thirdTable = LocalRelation(
-      AttributeReference("m", StringType)(),
-      AttributeReference("n", DecimalType.SYSTEM_DEFAULT)(),
-      AttributeReference("p", FloatType)(),
-      AttributeReference("q", DoubleType)())
-    val forthTable = LocalRelation(
-      AttributeReference("m", StringType)(),
-      AttributeReference("n", DecimalType.SYSTEM_DEFAULT)(),
-      AttributeReference("p", ByteType)(),
-      AttributeReference("q", DoubleType)())
-
-    val wt = HiveTypeCoercion.WidenSetOperationTypes
-    val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType)
-
-    val unionRelation = wt(
-      Union(firstTable :: secondTable :: thirdTable :: forthTable :: Nil)).asInstanceOf[Union]
-    assert(unionRelation.children.length == 4)
-    checkOutput(unionRelation.children.head, expectedTypes)
-    checkOutput(unionRelation.children(1), expectedTypes)
-    checkOutput(unionRelation.children(2), expectedTypes)
-    checkOutput(unionRelation.children(3), expectedTypes)
-
-    assert(unionRelation.children.head.isInstanceOf[Project])
-    assert(unionRelation.children(1).isInstanceOf[Project])
-    assert(unionRelation.children(2).isInstanceOf[Project])
-    assert(unionRelation.children(3).isInstanceOf[Project])
-  }
-
-  test("Transform Decimal precision/scale for union except and intersect") {
-    def checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = {
-      logical.output.zip(expectTypes).foreach { case (attr, dt) =>
-        assert(attr.dataType === dt)
-      }
-    }
-
-    val dp = HiveTypeCoercion.WidenSetOperationTypes
-
-    val left1 = LocalRelation(
-      AttributeReference("l", DecimalType(10, 8))())
-    val right1 = LocalRelation(
-      AttributeReference("r", DecimalType(5, 5))())
-    val expectedType1 = Seq(DecimalType(10, 8))
-
-    val r1 = dp(Union(left1, right1)).asInstanceOf[Union]
-    val r2 = dp(Except(left1, right1)).asInstanceOf[Except]
-    val r3 = dp(Intersect(left1, right1)).asInstanceOf[Intersect]
-
-    checkOutput(r1.children.head, expectedType1)
-    checkOutput(r1.children.last, expectedType1)
-    checkOutput(r2.left, expectedType1)
-    checkOutput(r2.right, expectedType1)
-    checkOutput(r3.left, expectedType1)
-    checkOutput(r3.right, expectedType1)
-
-    val plan1 = LocalRelation(AttributeReference("l", DecimalType(10, 5))())
-
-    val rightTypes = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType)
-    val expectedTypes = Seq(DecimalType(10, 5), DecimalType(10, 5), DecimalType(15, 5),
-      DecimalType(25, 5), DoubleType, DoubleType)
-
-    rightTypes.zip(expectedTypes).foreach { case (rType, expectedType) =>
-      val plan2 = LocalRelation(
-        AttributeReference("r", rType)())
-
-      val r1 = dp(Union(plan1, plan2)).asInstanceOf[Union]
-      val r2 = dp(Except(plan1, plan2)).asInstanceOf[Except]
-      val r3 = dp(Intersect(plan1, plan2)).asInstanceOf[Intersect]
-
-      checkOutput(r1.children.last, Seq(expectedType))
-      checkOutput(r2.right, Seq(expectedType))
-      checkOutput(r3.right, Seq(expectedType))
-
-      val r4 = dp(Union(plan2, plan1)).asInstanceOf[Union]
-      val r5 = dp(Except(plan2, plan1)).asInstanceOf[Except]
-      val r6 = dp(Intersect(plan2, plan1)).asInstanceOf[Intersect]
-
-      checkOutput(r4.children.last, Seq(expectedType))
-      checkOutput(r5.left, Seq(expectedType))
-      checkOutput(r6.left, Seq(expectedType))
-    }
-  }
-
-  test("rule for date/timestamp operations") {
-    val dateTimeOperations = HiveTypeCoercion.DateTimeOperations
-    val date = Literal(new java.sql.Date(0L))
-    val timestamp = Literal(new Timestamp(0L))
-    val interval = Literal(new CalendarInterval(0, 0))
-    val str = Literal("2015-01-01")
-
-    ruleTest(dateTimeOperations, Add(date, interval), Cast(TimeAdd(date, interval), DateType))
-    ruleTest(dateTimeOperations, Add(interval, date), Cast(TimeAdd(date, interval), DateType))
-    ruleTest(dateTimeOperations, Add(timestamp, interval),
-      Cast(TimeAdd(timestamp, interval), TimestampType))
-    ruleTest(dateTimeOperations, Add(interval, timestamp),
-      Cast(TimeAdd(timestamp, interval), TimestampType))
-    ruleTest(dateTimeOperations, Add(str, interval), Cast(TimeAdd(str, interval), StringType))
-    ruleTest(dateTimeOperations, Add(interval, str), Cast(TimeAdd(str, interval), StringType))
-
-    ruleTest(dateTimeOperations, Subtract(date, interval), Cast(TimeSub(date, interval), DateType))
-    ruleTest(dateTimeOperations, Subtract(timestamp, interval),
-      Cast(TimeSub(timestamp, interval), TimestampType))
-    ruleTest(dateTimeOperations, Subtract(str, interval), Cast(TimeSub(str, interval), StringType))
-
-    // interval operations should not be affected
-    ruleTest(dateTimeOperations, Add(interval, interval), Add(interval, interval))
-    ruleTest(dateTimeOperations, Subtract(interval, interval), Subtract(interval, interval))
-  }
-
-  /**
-   * There are rules that need to not fire before child expressions get resolved.
-   * We use this test to make sure those rules do not fire early.
-   */
-  test("make sure rules do not fire early") {
-    // InConversion
-    val inConversion = HiveTypeCoercion.InConversion
-    ruleTest(inConversion,
-      In(UnresolvedAttribute("a"), Seq(Literal(1))),
-      In(UnresolvedAttribute("a"), Seq(Literal(1)))
-    )
-    ruleTest(inConversion,
-      In(Literal("test"), Seq(UnresolvedAttribute("a"), Literal(1))),
-      In(Literal("test"), Seq(UnresolvedAttribute("a"), Literal(1)))
-    )
-    ruleTest(inConversion,
-      In(Literal("a"), Seq(Literal(1), Literal("b"))),
-      In(Cast(Literal("a"), StringType),
-        Seq(Cast(Literal(1), StringType), Cast(Literal("b"), StringType)))
-    )
-  }
-}
-
-
-object HiveTypeCoercionSuite {
-
-  case class AnyTypeUnaryExpression(child: Expression)
-    extends UnaryExpression with ExpectsInputTypes with Unevaluable {
-    override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
-    override def dataType: DataType = NullType
-  }
-
-  case class NumericTypeUnaryExpression(child: Expression)
-    extends UnaryExpression with ExpectsInputTypes with Unevaluable {
-    override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
-    override def dataType: DataType = NullType
-  }
-
-  case class AnyTypeBinaryOperator(left: Expression, right: Expression)
-    extends BinaryOperator with Unevaluable {
-    override def dataType: DataType = NullType
-    override def inputType: AbstractDataType = AnyDataType
-    override def symbol: String = "anytype"
-  }
-
-  case class NumericTypeBinaryOperator(left: Expression, right: Expression)
-    extends BinaryOperator with Unevaluable {
-    override def dataType: DataType = NullType
-    override def inputType: AbstractDataType = NumericType
-    override def symbol: String = "numerictype"
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/43570c57/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
new file mode 100644
index 0000000..7435399
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -0,0 +1,663 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import java.sql.Timestamp
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.CalendarInterval
+
+class TypeCoercionSuite extends PlanTest {
+
+  test("eligible implicit type cast") {
+    def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = {
+      val got = TypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, from), to)
+      assert(got.map(_.dataType) == Option(expected),
+        s"Failed to cast $from to $to")
+    }
+
+    shouldCast(NullType, NullType, NullType)
+    shouldCast(NullType, IntegerType, IntegerType)
+    shouldCast(NullType, DecimalType, DecimalType.SYSTEM_DEFAULT)
+
+    shouldCast(ByteType, IntegerType, IntegerType)
+    shouldCast(IntegerType, IntegerType, IntegerType)
+    shouldCast(IntegerType, LongType, LongType)
+    shouldCast(IntegerType, DecimalType, DecimalType(10, 0))
+    shouldCast(LongType, IntegerType, IntegerType)
+    shouldCast(LongType, DecimalType, DecimalType(20, 0))
+
+    shouldCast(DateType, TimestampType, TimestampType)
+    shouldCast(TimestampType, DateType, DateType)
+
+    shouldCast(StringType, IntegerType, IntegerType)
+    shouldCast(StringType, DateType, DateType)
+    shouldCast(StringType, TimestampType, TimestampType)
+    shouldCast(IntegerType, StringType, StringType)
+    shouldCast(DateType, StringType, StringType)
+    shouldCast(TimestampType, StringType, StringType)
+
+    shouldCast(StringType, BinaryType, BinaryType)
+    shouldCast(BinaryType, StringType, StringType)
+
+    shouldCast(NullType, TypeCollection(StringType, BinaryType), StringType)
+
+    shouldCast(StringType, TypeCollection(StringType, BinaryType), StringType)
+    shouldCast(BinaryType, TypeCollection(StringType, BinaryType), BinaryType)
+    shouldCast(StringType, TypeCollection(BinaryType, StringType), StringType)
+
+    shouldCast(IntegerType, TypeCollection(IntegerType, BinaryType), IntegerType)
+    shouldCast(IntegerType, TypeCollection(BinaryType, IntegerType), IntegerType)
+    shouldCast(BinaryType, TypeCollection(BinaryType, IntegerType), BinaryType)
+    shouldCast(BinaryType, TypeCollection(IntegerType, BinaryType), BinaryType)
+
+    shouldCast(IntegerType, TypeCollection(StringType, BinaryType), StringType)
+    shouldCast(IntegerType, TypeCollection(BinaryType, StringType), StringType)
+
+    shouldCast(DecimalType.SYSTEM_DEFAULT,
+      TypeCollection(IntegerType, DecimalType), DecimalType.SYSTEM_DEFAULT)
+    shouldCast(DecimalType(10, 2), TypeCollection(IntegerType, DecimalType), DecimalType(10, 2))
+    shouldCast(DecimalType(10, 2), TypeCollection(DecimalType, IntegerType), DecimalType(10, 2))
+    shouldCast(IntegerType, TypeCollection(DecimalType(10, 2), StringType), DecimalType(10, 2))
+
+    shouldCast(StringType, NumericType, DoubleType)
+    shouldCast(StringType, TypeCollection(NumericType, BinaryType), DoubleType)
+
+    // NumericType should not be changed when function accepts any of them.
+    Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType,
+      DecimalType.SYSTEM_DEFAULT, DecimalType(10, 2)).foreach { tpe =>
+      shouldCast(tpe, NumericType, tpe)
+    }
+
+    shouldCast(
+      ArrayType(StringType, false),
+      TypeCollection(ArrayType(StringType), StringType),
+      ArrayType(StringType, false))
+
+    shouldCast(
+      ArrayType(StringType, true),
+      TypeCollection(ArrayType(StringType), StringType),
+      ArrayType(StringType, true))
+  }
+
+  test("ineligible implicit type cast") {
+    def shouldNotCast(from: DataType, to: AbstractDataType): Unit = {
+      val got = TypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, from), to)
+      assert(got.isEmpty, s"Should not be able to cast $from to $to, but got $got")
+    }
+
+    shouldNotCast(IntegerType, DateType)
+    shouldNotCast(IntegerType, TimestampType)
+    shouldNotCast(LongType, DateType)
+    shouldNotCast(LongType, TimestampType)
+    shouldNotCast(DecimalType.SYSTEM_DEFAULT, DateType)
+    shouldNotCast(DecimalType.SYSTEM_DEFAULT, TimestampType)
+
+    shouldNotCast(IntegerType, TypeCollection(DateType, TimestampType))
+
+    shouldNotCast(IntegerType, ArrayType)
+    shouldNotCast(IntegerType, MapType)
+    shouldNotCast(IntegerType, StructType)
+
+    shouldNotCast(CalendarIntervalType, StringType)
+
+    // Don't implicitly cast complex types to string.
+    shouldNotCast(ArrayType(StringType), StringType)
+    shouldNotCast(MapType(StringType, StringType), StringType)
+    shouldNotCast(new StructType().add("a1", StringType), StringType)
+    shouldNotCast(MapType(StringType, StringType), StringType)
+  }
+
+  test("tightest common bound for types") {
+    def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) {
+      var found = TypeCoercion.findTightestCommonTypeOfTwo(t1, t2)
+      assert(found == tightestCommon,
+        s"Expected $tightestCommon as tightest common type for $t1 and $t2, found $found")
+      // Test both directions to make sure the widening is symmetric.
+      found = TypeCoercion.findTightestCommonTypeOfTwo(t2, t1)
+      assert(found == tightestCommon,
+        s"Expected $tightestCommon as tightest common type for $t2 and $t1, found $found")
+    }
+
+    // Null
+    widenTest(NullType, NullType, Some(NullType))
+
+    // Boolean
+    widenTest(NullType, BooleanType, Some(BooleanType))
+    widenTest(BooleanType, BooleanType, Some(BooleanType))
+    widenTest(IntegerType, BooleanType, None)
+    widenTest(LongType, BooleanType, None)
+
+    // Integral
+    widenTest(NullType, ByteType, Some(ByteType))
+    widenTest(NullType, IntegerType, Some(IntegerType))
+    widenTest(NullType, LongType, Some(LongType))
+    widenTest(ShortType, IntegerType, Some(IntegerType))
+    widenTest(ShortType, LongType, Some(LongType))
+    widenTest(IntegerType, LongType, Some(LongType))
+    widenTest(LongType, LongType, Some(LongType))
+
+    // Floating point
+    widenTest(NullType, FloatType, Some(FloatType))
+    widenTest(NullType, DoubleType, Some(DoubleType))
+    widenTest(FloatType, DoubleType, Some(DoubleType))
+    widenTest(FloatType, FloatType, Some(FloatType))
+    widenTest(DoubleType, DoubleType, Some(DoubleType))
+
+    // Integral mixed with floating point.
+    widenTest(IntegerType, FloatType, Some(FloatType))
+    widenTest(IntegerType, DoubleType, Some(DoubleType))
+    widenTest(IntegerType, DoubleType, Some(DoubleType))
+    widenTest(LongType, FloatType, Some(FloatType))
+    widenTest(LongType, DoubleType, Some(DoubleType))
+
+    // No up-casting for fixed-precision decimal (this is handled by arithmetic rules)
+    widenTest(DecimalType(2, 1), DecimalType(3, 2), None)
+    widenTest(DecimalType(2, 1), DoubleType, None)
+    widenTest(DecimalType(2, 1), IntegerType, None)
+    widenTest(DoubleType, DecimalType(2, 1), None)
+    widenTest(IntegerType, DecimalType(2, 1), None)
+
+    // StringType
+    widenTest(NullType, StringType, Some(StringType))
+    widenTest(StringType, StringType, Some(StringType))
+    widenTest(IntegerType, StringType, None)
+    widenTest(LongType, StringType, None)
+
+    // TimestampType
+    widenTest(NullType, TimestampType, Some(TimestampType))
+    widenTest(TimestampType, TimestampType, Some(TimestampType))
+    widenTest(IntegerType, TimestampType, None)
+    widenTest(StringType, TimestampType, None)
+
+    // ComplexType
+    widenTest(NullType,
+      MapType(IntegerType, StringType, false),
+      Some(MapType(IntegerType, StringType, false)))
+    widenTest(NullType, StructType(Seq()), Some(StructType(Seq())))
+    widenTest(StringType, MapType(IntegerType, StringType, true), None)
+    widenTest(ArrayType(IntegerType), StructType(Seq()), None)
+  }
+
+  private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) {
+    val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
+    comparePlans(
+      rule(Project(Seq(Alias(initial, "a")()), testRelation)),
+      Project(Seq(Alias(transformed, "a")()), testRelation))
+  }
+
+  test("cast NullType for expressions that implement ExpectsInputTypes") {
+    import TypeCoercionSuite._
+
+    ruleTest(TypeCoercion.ImplicitTypeCasts,
+      AnyTypeUnaryExpression(Literal.create(null, NullType)),
+      AnyTypeUnaryExpression(Literal.create(null, NullType)))
+
+    ruleTest(TypeCoercion.ImplicitTypeCasts,
+      NumericTypeUnaryExpression(Literal.create(null, NullType)),
+      NumericTypeUnaryExpression(Literal.create(null, DoubleType)))
+  }
+
+  test("cast NullType for binary operators") {
+    import TypeCoercionSuite._
+
+    ruleTest(TypeCoercion.ImplicitTypeCasts,
+      AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)),
+      AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)))
+
+    ruleTest(TypeCoercion.ImplicitTypeCasts,
+      NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)),
+      NumericTypeBinaryOperator(Literal.create(null, DoubleType), Literal.create(null, DoubleType)))
+  }
+
+  test("coalesce casts") {
+    ruleTest(TypeCoercion.FunctionArgumentConversion,
+      Coalesce(Literal(1.0)
+        :: Literal(1)
+        :: Literal.create(1.0, FloatType)
+        :: Nil),
+      Coalesce(Cast(Literal(1.0), DoubleType)
+        :: Cast(Literal(1), DoubleType)
+        :: Cast(Literal.create(1.0, FloatType), DoubleType)
+        :: Nil))
+    ruleTest(TypeCoercion.FunctionArgumentConversion,
+      Coalesce(Literal(1L)
+        :: Literal(1)
+        :: Literal(new java.math.BigDecimal("1000000000000000000000"))
+        :: Nil),
+      Coalesce(Cast(Literal(1L), DecimalType(22, 0))
+        :: Cast(Literal(1), DecimalType(22, 0))
+        :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0))
+        :: Nil))
+  }
+
+  test("CreateArray casts") {
+    ruleTest(TypeCoercion.FunctionArgumentConversion,
+      CreateArray(Literal(1.0)
+        :: Literal(1)
+        :: Literal.create(1.0, FloatType)
+        :: Nil),
+      CreateArray(Cast(Literal(1.0), DoubleType)
+        :: Cast(Literal(1), DoubleType)
+        :: Cast(Literal.create(1.0, FloatType), DoubleType)
+        :: Nil))
+
+    ruleTest(TypeCoercion.FunctionArgumentConversion,
+      CreateArray(Literal(1.0)
+        :: Literal(1)
+        :: Literal("a")
+        :: Nil),
+      CreateArray(Cast(Literal(1.0), StringType)
+        :: Cast(Literal(1), StringType)
+        :: Cast(Literal("a"), StringType)
+        :: Nil))
+  }
+
+  test("CreateMap casts") {
+    // type coercion for map keys
+    ruleTest(TypeCoercion.FunctionArgumentConversion,
+      CreateMap(Literal(1)
+        :: Literal("a")
+        :: Literal.create(2.0, FloatType)
+        :: Literal("b")
+        :: Nil),
+      CreateMap(Cast(Literal(1), FloatType)
+        :: Literal("a")
+        :: Cast(Literal.create(2.0, FloatType), FloatType)
+        :: Literal("b")
+        :: Nil))
+    // type coercion for map values
+    ruleTest(TypeCoercion.FunctionArgumentConversion,
+      CreateMap(Literal(1)
+        :: Literal("a")
+        :: Literal(2)
+        :: Literal(3.0)
+        :: Nil),
+      CreateMap(Literal(1)
+        :: Cast(Literal("a"), StringType)
+        :: Literal(2)
+        :: Cast(Literal(3.0), StringType)
+        :: Nil))
+    // type coercion for both map keys and values
+    ruleTest(TypeCoercion.FunctionArgumentConversion,
+      CreateMap(Literal(1)
+        :: Literal("a")
+        :: Literal(2.0)
+        :: Literal(3.0)
+        :: Nil),
+      CreateMap(Cast(Literal(1), DoubleType)
+        :: Cast(Literal("a"), StringType)
+        :: Cast(Literal(2.0), DoubleType)
+        :: Cast(Literal(3.0), StringType)
+        :: Nil))
+  }
+
+  test("greatest/least cast") {
+    for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) {
+      ruleTest(TypeCoercion.FunctionArgumentConversion,
+        operator(Literal(1.0)
+          :: Literal(1)
+          :: Literal.create(1.0, FloatType)
+          :: Nil),
+        operator(Cast(Literal(1.0), DoubleType)
+          :: Cast(Literal(1), DoubleType)
+          :: Cast(Literal.create(1.0, FloatType), DoubleType)
+          :: Nil))
+      ruleTest(TypeCoercion.FunctionArgumentConversion,
+        operator(Literal(1L)
+          :: Literal(1)
+          :: Literal(new java.math.BigDecimal("1000000000000000000000"))
+          :: Nil),
+        operator(Cast(Literal(1L), DecimalType(22, 0))
+          :: Cast(Literal(1), DecimalType(22, 0))
+          :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0))
+          :: Nil))
+    }
+  }
+
+  test("nanvl casts") {
+    ruleTest(TypeCoercion.FunctionArgumentConversion,
+      NaNvl(Literal.create(1.0, FloatType), Literal.create(1.0, DoubleType)),
+      NaNvl(Cast(Literal.create(1.0, FloatType), DoubleType), Literal.create(1.0, DoubleType)))
+    ruleTest(TypeCoercion.FunctionArgumentConversion,
+      NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, FloatType)),
+      NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0, FloatType), DoubleType)))
+    ruleTest(TypeCoercion.FunctionArgumentConversion,
+      NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)),
+      NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)))
+  }
+
+  test("type coercion for If") {
+    val rule = TypeCoercion.IfCoercion
+
+    ruleTest(rule,
+      If(Literal(true), Literal(1), Literal(1L)),
+      If(Literal(true), Cast(Literal(1), LongType), Literal(1L)))
+
+    ruleTest(rule,
+      If(Literal.create(null, NullType), Literal(1), Literal(1)),
+      If(Literal.create(null, BooleanType), Literal(1), Literal(1)))
+
+    ruleTest(rule,
+      If(AssertTrue(Literal.create(true, BooleanType)), Literal(1), Literal(2)),
+      If(Cast(AssertTrue(Literal.create(true, BooleanType)), BooleanType), Literal(1), Literal(2)))
+
+    ruleTest(rule,
+      If(AssertTrue(Literal.create(false, BooleanType)), Literal(1), Literal(2)),
+      If(Cast(AssertTrue(Literal.create(false, BooleanType)), BooleanType), Literal(1), Literal(2)))
+  }
+
+  test("type coercion for CaseKeyWhen") {
+    ruleTest(TypeCoercion.ImplicitTypeCasts,
+      CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))),
+      CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a")))
+    )
+    ruleTest(TypeCoercion.CaseWhenCoercion,
+      CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))),
+      CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a")))
+    )
+    ruleTest(TypeCoercion.CaseWhenCoercion,
+      CaseWhen(Seq((Literal(true), Literal(1.2))), Literal.create(1, DecimalType(7, 2))),
+      CaseWhen(Seq((Literal(true), Literal(1.2))),
+        Cast(Literal.create(1, DecimalType(7, 2)), DoubleType))
+    )
+    ruleTest(TypeCoercion.CaseWhenCoercion,
+      CaseWhen(Seq((Literal(true), Literal(100L))), Literal.create(1, DecimalType(7, 2))),
+      CaseWhen(Seq((Literal(true), Cast(Literal(100L), DecimalType(22, 2)))),
+        Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2)))
+    )
+  }
+
+  test("BooleanEquality type cast") {
+    val be = TypeCoercion.BooleanEquality
+    // Use something more than a literal to avoid triggering the simplification rules.
+    val one = Add(Literal(Decimal(1)), Literal(Decimal(0)))
+
+    ruleTest(be,
+      EqualTo(Literal(true), one),
+      EqualTo(Cast(Literal(true), one.dataType), one)
+    )
+
+    ruleTest(be,
+      EqualTo(one, Literal(true)),
+      EqualTo(one, Cast(Literal(true), one.dataType))
+    )
+
+    ruleTest(be,
+      EqualNullSafe(Literal(true), one),
+      EqualNullSafe(Cast(Literal(true), one.dataType), one)
+    )
+
+    ruleTest(be,
+      EqualNullSafe(one, Literal(true)),
+      EqualNullSafe(one, Cast(Literal(true), one.dataType))
+    )
+  }
+
+  test("BooleanEquality simplification") {
+    val be = TypeCoercion.BooleanEquality
+
+    ruleTest(be,
+      EqualTo(Literal(true), Literal(1)),
+      Literal(true)
+    )
+    ruleTest(be,
+      EqualTo(Literal(true), Literal(0)),
+      Not(Literal(true))
+    )
+    ruleTest(be,
+      EqualNullSafe(Literal(true), Literal(1)),
+      And(IsNotNull(Literal(true)), Literal(true))
+    )
+    ruleTest(be,
+      EqualNullSafe(Literal(true), Literal(0)),
+      And(IsNotNull(Literal(true)), Not(Literal(true)))
+    )
+
+    ruleTest(be,
+      EqualTo(Literal(true), Literal(1L)),
+      Literal(true)
+    )
+    ruleTest(be,
+      EqualTo(Literal(new java.math.BigDecimal(1)), Literal(true)),
+      Literal(true)
+    )
+    ruleTest(be,
+      EqualTo(Literal(BigDecimal(0)), Literal(true)),
+      Not(Literal(true))
+    )
+    ruleTest(be,
+      EqualTo(Literal(Decimal(1)), Literal(true)),
+      Literal(true)
+    )
+    ruleTest(be,
+      EqualTo(Literal.create(Decimal(1), DecimalType(8, 0)), Literal(true)),
+      Literal(true)
+    )
+  }
+
+  private def checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = {
+    logical.output.zip(expectTypes).foreach { case (attr, dt) =>
+      assert(attr.dataType === dt)
+    }
+  }
+
+  test("WidenSetOperationTypes for except and intersect") {
+    val firstTable = LocalRelation(
+      AttributeReference("i", IntegerType)(),
+      AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(),
+      AttributeReference("b", ByteType)(),
+      AttributeReference("d", DoubleType)())
+    val secondTable = LocalRelation(
+      AttributeReference("s", StringType)(),
+      AttributeReference("d", DecimalType(2, 1))(),
+      AttributeReference("f", FloatType)(),
+      AttributeReference("l", LongType)())
+
+    val wt = TypeCoercion.WidenSetOperationTypes
+    val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType)
+
+    val r1 = wt(Except(firstTable, secondTable)).asInstanceOf[Except]
+    val r2 = wt(Intersect(firstTable, secondTable)).asInstanceOf[Intersect]
+    checkOutput(r1.left, expectedTypes)
+    checkOutput(r1.right, expectedTypes)
+    checkOutput(r2.left, expectedTypes)
+    checkOutput(r2.right, expectedTypes)
+
+    // Check if a Project is added
+    assert(r1.left.isInstanceOf[Project])
+    assert(r1.right.isInstanceOf[Project])
+    assert(r2.left.isInstanceOf[Project])
+    assert(r2.right.isInstanceOf[Project])
+  }
+
+  test("WidenSetOperationTypes for union") {
+    val firstTable = LocalRelation(
+      AttributeReference("i", IntegerType)(),
+      AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(),
+      AttributeReference("b", ByteType)(),
+      AttributeReference("d", DoubleType)())
+    val secondTable = LocalRelation(
+      AttributeReference("s", StringType)(),
+      AttributeReference("d", DecimalType(2, 1))(),
+      AttributeReference("f", FloatType)(),
+      AttributeReference("l", LongType)())
+    val thirdTable = LocalRelation(
+      AttributeReference("m", StringType)(),
+      AttributeReference("n", DecimalType.SYSTEM_DEFAULT)(),
+      AttributeReference("p", FloatType)(),
+      AttributeReference("q", DoubleType)())
+    val forthTable = LocalRelation(
+      AttributeReference("m", StringType)(),
+      AttributeReference("n", DecimalType.SYSTEM_DEFAULT)(),
+      AttributeReference("p", ByteType)(),
+      AttributeReference("q", DoubleType)())
+
+    val wt = TypeCoercion.WidenSetOperationTypes
+    val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType)
+
+    val unionRelation = wt(
+      Union(firstTable :: secondTable :: thirdTable :: forthTable :: Nil)).asInstanceOf[Union]
+    assert(unionRelation.children.length == 4)
+    checkOutput(unionRelation.children.head, expectedTypes)
+    checkOutput(unionRelation.children(1), expectedTypes)
+    checkOutput(unionRelation.children(2), expectedTypes)
+    checkOutput(unionRelation.children(3), expectedTypes)
+
+    assert(unionRelation.children.head.isInstanceOf[Project])
+    assert(unionRelation.children(1).isInstanceOf[Project])
+    assert(unionRelation.children(2).isInstanceOf[Project])
+    assert(unionRelation.children(3).isInstanceOf[Project])
+  }
+
+  test("Transform Decimal precision/scale for union except and intersect") {
+    def checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = {
+      logical.output.zip(expectTypes).foreach { case (attr, dt) =>
+        assert(attr.dataType === dt)
+      }
+    }
+
+    val dp = TypeCoercion.WidenSetOperationTypes
+
+    val left1 = LocalRelation(
+      AttributeReference("l", DecimalType(10, 8))())
+    val right1 = LocalRelation(
+      AttributeReference("r", DecimalType(5, 5))())
+    val expectedType1 = Seq(DecimalType(10, 8))
+
+    val r1 = dp(Union(left1, right1)).asInstanceOf[Union]
+    val r2 = dp(Except(left1, right1)).asInstanceOf[Except]
+    val r3 = dp(Intersect(left1, right1)).asInstanceOf[Intersect]
+
+    checkOutput(r1.children.head, expectedType1)
+    checkOutput(r1.children.last, expectedType1)
+    checkOutput(r2.left, expectedType1)
+    checkOutput(r2.right, expectedType1)
+    checkOutput(r3.left, expectedType1)
+    checkOutput(r3.right, expectedType1)
+
+    val plan1 = LocalRelation(AttributeReference("l", DecimalType(10, 5))())
+
+    val rightTypes = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType)
+    val expectedTypes = Seq(DecimalType(10, 5), DecimalType(10, 5), DecimalType(15, 5),
+      DecimalType(25, 5), DoubleType, DoubleType)
+
+    rightTypes.zip(expectedTypes).foreach { case (rType, expectedType) =>
+      val plan2 = LocalRelation(
+        AttributeReference("r", rType)())
+
+      val r1 = dp(Union(plan1, plan2)).asInstanceOf[Union]
+      val r2 = dp(Except(plan1, plan2)).asInstanceOf[Except]
+      val r3 = dp(Intersect(plan1, plan2)).asInstanceOf[Intersect]
+
+      checkOutput(r1.children.last, Seq(expectedType))
+      checkOutput(r2.right, Seq(expectedType))
+      checkOutput(r3.right, Seq(expectedType))
+
+      val r4 = dp(Union(plan2, plan1)).asInstanceOf[Union]
+      val r5 = dp(Except(plan2, plan1)).asInstanceOf[Except]
+      val r6 = dp(Intersect(plan2, plan1)).asInstanceOf[Intersect]
+
+      checkOutput(r4.children.last, Seq(expectedType))
+      checkOutput(r5.left, Seq(expectedType))
+      checkOutput(r6.left, Seq(expectedType))
+    }
+  }
+
+  test("rule for date/timestamp operations") {
+    val dateTimeOperations = TypeCoercion.DateTimeOperations
+    val date = Literal(new java.sql.Date(0L))
+    val timestamp = Literal(new Timestamp(0L))
+    val interval = Literal(new CalendarInterval(0, 0))
+    val str = Literal("2015-01-01")
+
+    ruleTest(dateTimeOperations, Add(date, interval), Cast(TimeAdd(date, interval), DateType))
+    ruleTest(dateTimeOperations, Add(interval, date), Cast(TimeAdd(date, interval), DateType))
+    ruleTest(dateTimeOperations, Add(timestamp, interval),
+      Cast(TimeAdd(timestamp, interval), TimestampType))
+    ruleTest(dateTimeOperations, Add(interval, timestamp),
+      Cast(TimeAdd(timestamp, interval), TimestampType))
+    ruleTest(dateTimeOperations, Add(str, interval), Cast(TimeAdd(str, interval), StringType))
+    ruleTest(dateTimeOperations, Add(interval, str), Cast(TimeAdd(str, interval), StringType))
+
+    ruleTest(dateTimeOperations, Subtract(date, interval), Cast(TimeSub(date, interval), DateType))
+    ruleTest(dateTimeOperations, Subtract(timestamp, interval),
+      Cast(TimeSub(timestamp, interval), TimestampType))
+    ruleTest(dateTimeOperations, Subtract(str, interval), Cast(TimeSub(str, interval), StringType))
+
+    // interval operations should not be affected
+    ruleTest(dateTimeOperations, Add(interval, interval), Add(interval, interval))
+    ruleTest(dateTimeOperations, Subtract(interval, interval), Subtract(interval, interval))
+  }
+
+  /**
+   * There are rules that need to not fire before child expressions get resolved.
+   * We use this test to make sure those rules do not fire early.
+   */
+  test("make sure rules do not fire early") {
+    // InConversion
+    val inConversion = TypeCoercion.InConversion
+    ruleTest(inConversion,
+      In(UnresolvedAttribute("a"), Seq(Literal(1))),
+      In(UnresolvedAttribute("a"), Seq(Literal(1)))
+    )
+    ruleTest(inConversion,
+      In(Literal("test"), Seq(UnresolvedAttribute("a"), Literal(1))),
+      In(Literal("test"), Seq(UnresolvedAttribute("a"), Literal(1)))
+    )
+    ruleTest(inConversion,
+      In(Literal("a"), Seq(Literal(1), Literal("b"))),
+      In(Cast(Literal("a"), StringType),
+        Seq(Cast(Literal(1), StringType), Cast(Literal("b"), StringType)))
+    )
+  }
+}
+
+
+object TypeCoercionSuite {
+
+  case class AnyTypeUnaryExpression(child: Expression)
+    extends UnaryExpression with ExpectsInputTypes with Unevaluable {
+    override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+    override def dataType: DataType = NullType
+  }
+
+  case class NumericTypeUnaryExpression(child: Expression)
+    extends UnaryExpression with ExpectsInputTypes with Unevaluable {
+    override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
+    override def dataType: DataType = NullType
+  }
+
+  case class AnyTypeBinaryOperator(left: Expression, right: Expression)
+    extends BinaryOperator with Unevaluable {
+    override def dataType: DataType = NullType
+    override def inputType: AbstractDataType = AnyDataType
+    override def symbol: String = "anytype"
+  }
+
+  case class NumericTypeBinaryOperator(left: Expression, right: Expression)
+    extends BinaryOperator with Unevaluable {
+    override def dataType: DataType = NullType
+    override def inputType: AbstractDataType = NumericType
+    override def symbol: String = "numerictype"
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/43570c57/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
index 05c8d8e..de3d889 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
@@ -25,7 +25,7 @@ import scala.util.control.Exception._
 import scala.util.Try
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
+import org.apache.spark.sql.catalyst.analysis.TypeCoercion
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -171,11 +171,11 @@ private[csv] object CSVInferSchema {
     StringType
   }
 
-  private val numericPrecedence: IndexedSeq[DataType] = HiveTypeCoercion.numericPrecedence
+  private val numericPrecedence: IndexedSeq[DataType] = TypeCoercion.numericPrecedence
 
   /**
    * Copied from internal Spark api
-   * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion]]
+   * [[org.apache.spark.sql.catalyst.analysis.TypeCoercion]]
    */
   val findTightestCommonType: (DataType, DataType) => Option[DataType] = {
     case (t1, t2) if t1 == t2 => Some(t1)

http://git-wip-us.apache.org/repos/asf/spark/blob/43570c57/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala
index 42c8262..579b036 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala
@@ -22,7 +22,7 @@ import java.util.Comparator
 import com.fasterxml.jackson.core._
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
+import org.apache.spark.sql.catalyst.analysis.TypeCoercion
 import org.apache.spark.sql.execution.datasources.json.JacksonUtils.nextUntil
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
@@ -252,7 +252,7 @@ private[sql] object InferSchema {
    * Returns the most general data type for two given data types.
    */
   def compatibleType(t1: DataType, t2: DataType): DataType = {
-    HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2).getOrElse {
+    TypeCoercion.findTightestCommonTypeOfTwo(t1, t2).getOrElse {
       // t1 or t2 is a StructType, ArrayType, or an unexpected type.
       (t1, t2) match {
         // Double support larger range than fixed decimal, DecimalType.Maximum should be enough


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message