Return-Path: X-Original-To: apmail-spark-commits-archive@minotaur.apache.org Delivered-To: apmail-spark-commits-archive@minotaur.apache.org Received: from mail.apache.org (hermes.apache.org [140.211.11.3]) by minotaur.apache.org (Postfix) with SMTP id 0ECF819693 for ; Fri, 25 Mar 2016 19:08:01 +0000 (UTC) Received: (qmail 45569 invoked by uid 500); 25 Mar 2016 19:08:01 -0000 Delivered-To: apmail-spark-commits-archive@spark.apache.org Received: (qmail 45538 invoked by uid 500); 25 Mar 2016 19:08:01 -0000 Mailing-List: contact commits-help@spark.apache.org; run by ezmlm Precedence: bulk List-Help: List-Unsubscribe: List-Post: List-Id: Delivered-To: mailing list commits@spark.apache.org Received: (qmail 45527 invoked by uid 99); 25 Mar 2016 19:08:00 -0000 Received: from git1-us-west.apache.org (HELO git1-us-west.apache.org) (140.211.11.23) by apache.org (qpsmtpd/0.29) with ESMTP; Fri, 25 Mar 2016 19:08:00 +0000 Received: by git1-us-west.apache.org (ASF Mail Server at git1-us-west.apache.org, from userid 33) id BC8BADFC74; Fri, 25 Mar 2016 19:08:00 +0000 (UTC) Content-Type: text/plain; charset="us-ascii" MIME-Version: 1.0 Content-Transfer-Encoding: 7bit From: marmbrus@apache.org To: commits@spark.apache.org Message-Id: <9a62de00a20b4ecfbc3fdd3062c3df4d@git.apache.org> X-Mailer: ASF-Git Admin Mailer Subject: spark git commit: [SPARK-12443][SQL] encoderFor should support Decimal Date: Fri, 25 Mar 2016 19:08:00 +0000 (UTC) Repository: spark Updated Branches: refs/heads/master 11fa8741c -> ca003354d [SPARK-12443][SQL] encoderFor should support Decimal ## What changes were proposed in this pull request? JIRA: https://issues.apache.org/jira/browse/SPARK-12443 `constructorFor` will call `dataTypeFor` to determine if a type is `ObjectType` or not. If there is no case for `Decimal`, it will be recognized as `ObjectType`, which causes the bug. ## How was this patch tested? A test is added to `ExpressionEncoderSuite`. 
Author: Liang-Chi Hsieh Author: Liang-Chi Hsieh Closes #10399 from viirya/fix-encoder-decimal. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ca003354 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ca003354 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ca003354 Branch: refs/heads/master Commit: ca003354da5e738e97418efc5af07be071c16d8f Parents: 11fa874 Author: Liang-Chi Hsieh Authored: Fri Mar 25 12:07:56 2016 -0700 Committer: Michael Armbrust Committed: Fri Mar 25 12:07:56 2016 -0700 ---------------------------------------------------------------------- .../spark/sql/catalyst/ScalaReflection.scala | 1 + .../sql/catalyst/encoders/RowEncoder.scala | 21 +++++++++++++++++--- .../org/apache/spark/sql/types/Decimal.scala | 8 ++++++++ .../encoders/ExpressionEncoderSuite.scala | 4 +++- .../sql/catalyst/encoders/RowEncoderSuite.scala | 17 ++++++++++++++++ 5 files changed, 47 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/ca003354/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 5e1672c..f208401 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -63,6 +63,7 @@ object ScalaReflection extends ScalaReflection { case t if t <:< definitions.ByteTpe => ByteType case t if t <:< definitions.BooleanTpe => BooleanType case t if t <:< localTypeOf[Array[Byte]] => BinaryType + case t if t <:< localTypeOf[Decimal] => DecimalType.SYSTEM_DEFAULT case _ => val className = 
getClassNameFromType(tpe) className match { http://git-wip-us.apache.org/repos/asf/spark/blob/ca003354/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 902644e..30f56d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -79,7 +79,7 @@ object RowEncoder { StaticInvoke( Decimal.getClass, DecimalType.SYSTEM_DEFAULT, - "apply", + "fromDecimal", inputObject :: Nil) case StringType => @@ -95,7 +95,7 @@ object RowEncoder { classOf[GenericArrayData], inputObject :: Nil, dataType = t) - case _ => MapObjects(extractorsFor(_, et), inputObject, externalDataTypeFor(et)) + case _ => MapObjects(extractorsFor(_, et), inputObject, externalDataTypeForInput(et)) } case t @ MapType(kt, vt, valueNullable) => @@ -129,7 +129,7 @@ object RowEncoder { Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil), Literal.create(null, f.dataType), extractorsFor( - Invoke(inputObject, method, externalDataTypeFor(f.dataType), Literal(i) :: Nil), + Invoke(inputObject, method, externalDataTypeForInput(f.dataType), Literal(i) :: Nil), f.dataType)) } If(IsNull(inputObject), @@ -137,6 +137,21 @@ object RowEncoder { CreateStruct(convertedFields)) } + /** + * Returns the `DataType` that can be used when generating code that converts input data + * into the Spark SQL internal format. Unlike `externalDataTypeFor`, the `DataType` returned + * by this function can be more permissive since multiple external types may map to a single + * internal type. 
For example, for an input with DecimalType in external row, its external types + * can be `scala.math.BigDecimal`, `java.math.BigDecimal`, or + * `org.apache.spark.sql.types.Decimal`. + */ + private def externalDataTypeForInput(dt: DataType): DataType = dt match { + // In order to support both Decimal and java BigDecimal in external row, we make this + // as java.lang.Object. + case _: DecimalType => ObjectType(classOf[java.lang.Object]) + case _ => externalDataTypeFor(dt) + } + private def externalDataTypeFor(dt: DataType): DataType = dt match { case _ if ScalaReflection.isNativeType(dt) => dt case CalendarIntervalType => dt http://git-wip-us.apache.org/repos/asf/spark/blob/ca003354/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index f0e535b..a30a392 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -376,6 +376,14 @@ object Decimal { def apply(value: String): Decimal = new Decimal().set(BigDecimal(value)) + // This is used for RowEncoder to handle Decimal inside external row. + def fromDecimal(value: Any): Decimal = { + value match { + case j: java.math.BigDecimal => apply(j) + case d: Decimal => d + } + } + /** * Creates a decimal from unscaled, precision and scale without checking the bounds. 
*/ http://git-wip-us.apache.org/repos/asf/spark/blob/ca003354/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 3024858..f6583bf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import org.apache.spark.sql.catalyst.util.ArrayData -import org.apache.spark.sql.types.{ArrayType, ObjectType, StructType} +import org.apache.spark.sql.types.{ArrayType, Decimal, ObjectType, StructType} case class RepeatedStruct(s: Seq[PrimitiveData]) @@ -101,6 +101,8 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest { encodeDecodeTest(BigDecimal("32131413.211321313"), "scala decimal") // encodeDecodeTest(new java.math.BigDecimal("231341.23123"), "java decimal") + encodeDecodeTest(Decimal("32131413.211321313"), "catalyst decimal") + encodeDecodeTest("hello", "string") encodeDecodeTest(Date.valueOf("2012-12-23"), "date") encodeDecodeTest(Timestamp.valueOf("2016-01-29 10:00:00"), "timestamp") http://git-wip-us.apache.org/repos/asf/spark/blob/ca003354/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index bf0360c..a8fa372 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -143,6 +143,23 @@ class RowEncoderSuite extends SparkFunSuite { assert(input.getStruct(0) == convertedBack.getStruct(0)) } + test("encode/decode Decimal") { + val schema = new StructType() + .add("int", IntegerType) + .add("string", StringType) + .add("double", DoubleType) + .add("decimal", DecimalType.SYSTEM_DEFAULT) + + val encoder = RowEncoder(schema) + + val input: Row = Row(100, "test", 0.123, Decimal(1234.5678)) + val row = encoder.toRow(input) + val convertedBack = encoder.fromRow(row) + // Decimal inside external row will be converted back to Java BigDecimal when decoding. + assert(input.get(3).asInstanceOf[Decimal].toJavaBigDecimal + .compareTo(convertedBack.getDecimal(3)) == 0) + } + private def encodeDecodeTest(schema: StructType): Unit = { test(s"encode/decode: ${schema.simpleString}") { val encoder = RowEncoder(schema) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org For additional commands, e-mail: commits-help@spark.apache.org