spark-commits mailing list archives

From wenc...@apache.org
Subject spark git commit: [SPARK-11827][SQL] Adding java.math.BigInteger support in Java type inference for POJOs and Java collections
Date Fri, 20 May 2016 04:41:20 GMT
Repository: spark
Updated Branches:
  refs/heads/master d5c47f8ff -> 17591d90e


[SPARK-11827][SQL] Adding java.math.BigInteger support in Java type inference for POJOs and Java collections

Hello: Can you help check this PR? I am adding support for java.math.BigInteger on the Java
bean code path. I saw that internally Spark already converts BigInteger to BigDecimal in
ColumnType.scala and CatalystRowConverter.scala; I use the same approach and convert
BigInteger to BigDecimal.

Author: Kevin Yu <qyu@us.ibm.com>

Closes #10125 from kevinyu98/working_on_spark-11827.
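
(Reviewer context, not part of the commit: a minimal sketch of what the patch enables,
assuming a build of master that includes it; the names below are illustrative.)

    // With this patch, fields of type java.math.BigInteger or scala.math.BigInt
    // are inferred as Spark SQL decimal(38,0).
    import org.apache.spark.sql.SparkSession

    case class Account(id: Long, balance: scala.math.BigInt)

    object BigIntExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("bigint").getOrCreate()
        import spark.implicits._

        val df = Seq(Account(1L, BigInt("1234567890123456789"))).toDF()
        df.printSchema()
        // root
        //  |-- id: long (nullable = false)
        //  |-- balance: decimal(38,0) (nullable = true)
        spark.stop()
      }
    }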


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/17591d90
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/17591d90
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/17591d90

Branch: refs/heads/master
Commit: 17591d90e6873f30a042112f56a1686726ccbd60
Parents: d5c47f8
Author: Kevin Yu <qyu@us.ibm.com>
Authored: Fri May 20 12:41:14 2016 +0800
Committer: Wenchen Fan <wenchen@databricks.com>
Committed: Fri May 20 12:41:14 2016 +0800

----------------------------------------------------------------------
 .../sql/catalyst/CatalystTypeConverters.scala   |  2 ++
 .../spark/sql/catalyst/JavaTypeInference.scala  |  1 +
 .../spark/sql/catalyst/ScalaReflection.scala    | 24 ++++++++++++++++
 .../org/apache/spark/sql/types/Decimal.scala    | 29 +++++++++++++++++++-
 .../apache/spark/sql/types/DecimalType.scala    |  1 +
 .../encoders/ExpressionEncoderSuite.scala       |  4 ++-
 .../apache/spark/sql/JavaDataFrameSuite.java    | 11 +++++++-
 .../sql/ScalaReflectionRelationSuite.scala      | 10 +++++--
 8 files changed, 76 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/17591d90/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 9bfc381..9cc7b2a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst
 
 import java.lang.{Iterable => JavaIterable}
 import java.math.{BigDecimal => JavaBigDecimal}
+import java.math.{BigInteger => JavaBigInteger}
 import java.sql.{Date, Timestamp}
 import java.util.{Map => JavaMap}
 import javax.annotation.Nullable
@@ -326,6 +327,7 @@ object CatalystTypeConverters {
       val decimal = scalaValue match {
         case d: BigDecimal => Decimal(d)
         case d: JavaBigDecimal => Decimal(d)
+        case d: JavaBigInteger => Decimal(d)
         case d: Decimal => d
       }
       if (decimal.changePrecision(dataType.precision, dataType.scale)) {

http://git-wip-us.apache.org/repos/asf/spark/blob/17591d90/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 6907582..1fe1434 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -89,6 +89,7 @@ object JavaTypeInference {
       case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true)
 
      case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType.SYSTEM_DEFAULT, true)
+      case c: Class[_] if c == classOf[java.math.BigInteger] => (DecimalType.BigIntDecimal, true)
       case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
       case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/17591d90/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index c0fa220..58df651 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -259,6 +259,12 @@ object ScalaReflection extends ScalaReflection {
       case t if t <:< localTypeOf[BigDecimal] =>
         Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal]))
 
+      case t if t <:< localTypeOf[java.math.BigInteger] =>
+        Invoke(getPath, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]))
+
+      case t if t <:< localTypeOf[scala.math.BigInt] =>
+        Invoke(getPath, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]))
+
       case t if t <:< localTypeOf[Array[_]] =>
         val TypeRef(_, _, Seq(elementType)) = t
 
@@ -592,6 +598,20 @@ object ScalaReflection extends ScalaReflection {
             "apply",
             inputObject :: Nil)
 
+        case t if t <:< localTypeOf[java.math.BigInteger] =>
+          StaticInvoke(
+            Decimal.getClass,
+            DecimalType.BigIntDecimal,
+            "apply",
+            inputObject :: Nil)
+
+        case t if t <:< localTypeOf[scala.math.BigInt] =>
+          StaticInvoke(
+            Decimal.getClass,
+            DecimalType.BigIntDecimal,
+            "apply",
+            inputObject :: Nil)
+
         case t if t <:< localTypeOf[java.lang.Integer] =>
           Invoke(inputObject, "intValue", IntegerType)
         case t if t <:< localTypeOf[java.lang.Long] =>
@@ -736,6 +756,10 @@ object ScalaReflection extends ScalaReflection {
      case t if t <:< localTypeOf[BigDecimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
       case t if t <:< localTypeOf[java.math.BigDecimal] =>
         Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
+      case t if t <:< localTypeOf[java.math.BigInteger] =>
+        Schema(DecimalType.BigIntDecimal, nullable = true)
+      case t if t <:< localTypeOf[scala.math.BigInt] =>
+        Schema(DecimalType.BigIntDecimal, nullable = true)
      case t if t <:< localTypeOf[Decimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
      case t if t <:< localTypeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
      case t if t <:< localTypeOf[java.lang.Long] => Schema(LongType, nullable = true)

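(Illustration, not part of the commit: the schemaFor mapping added above can be checked
directly; a sketch assuming this patch is applied.)

    import org.apache.spark.sql.catalyst.ScalaReflection
    import org.apache.spark.sql.types.DecimalType

    object SchemaForBigInt {
      def main(args: Array[String]): Unit = {
        val schema = ScalaReflection.schemaFor[scala.math.BigInt]
        println(schema)  // Schema(DecimalType(38,0),true), i.e. DecimalType.BigIntDecimal
        assert(schema.dataType == DecimalType(38, 0))
      }
    }
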
http://git-wip-us.apache.org/repos/asf/spark/blob/17591d90/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 2f7422b..b907f62 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.types
 
-import java.math.{MathContext, RoundingMode}
+import java.math.{BigInteger, MathContext, RoundingMode}
 
 import org.apache.spark.annotation.DeveloperApi
 
@@ -129,6 +129,23 @@ final class Decimal extends Ordered[Decimal] with Serializable {
   }
 
   /**
+   * Set this Decimal to the given BigInteger value. Will have precision 38 and scale 0.
+   */
+  def set(bigintval: BigInteger): Decimal = {
+    try {
+      this.decimalVal = null
+      this.longVal = bigintval.longValueExact()
+      this._precision = DecimalType.MAX_PRECISION
+      this._scale = 0
+      this
+    } catch {
+      case e: ArithmeticException =>
+        throw new IllegalArgumentException(s"BigInteger ${bigintval} too large for decimal")
+    }
+  }
+
+  /**
    * Set this Decimal to the given Decimal value.
    */
   def set(decimal: Decimal): Decimal = {
@@ -155,6 +172,10 @@ final class Decimal extends Ordered[Decimal] with Serializable {
     }
   }
 
+  def toScalaBigInt: BigInt = BigInt(toLong)
+
+  def toJavaBigInteger: java.math.BigInteger = java.math.BigInteger.valueOf(toLong)
+
   def toUnscaledLong: Long = {
     if (decimalVal.ne(null)) {
       decimalVal.underlying().unscaledValue().longValue()
@@ -371,6 +392,10 @@ object Decimal {
 
   def apply(value: java.math.BigDecimal): Decimal = new Decimal().set(value)
 
+  def apply(value: java.math.BigInteger): Decimal = new Decimal().set(value)
+
+  def apply(value: scala.math.BigInt): Decimal = new Decimal().set(value.bigInteger)
+
   def apply(value: BigDecimal, precision: Int, scale: Int): Decimal =
     new Decimal().set(value, precision, scale)
 
@@ -387,6 +412,8 @@ object Decimal {
     value match {
       case j: java.math.BigDecimal => apply(j)
       case d: BigDecimal => apply(d)
+      case k: scala.math.BigInt => apply(k)
+      case l: java.math.BigInteger => apply(l)
       case d: Decimal => d
     }
   }

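(Illustration, not part of the commit: how the new Decimal conversions behave. Note that
set(BigInteger) stores the value through longValueExact(), so BigIntegers outside the
Long range are rejected, and toJavaBigInteger/toScalaBigInt round-trip via toLong.)

    import org.apache.spark.sql.types.Decimal

    object DecimalBigIntSketch {
      def main(args: Array[String]): Unit = {
        val d = Decimal(new java.math.BigInteger("1234567"))
        println((d.precision, d.scale))  // (38,0), matching DecimalType.BigIntDecimal
        println(d.toJavaBigInteger)      // 1234567
        println(d.toScalaBigInt)         // 1234567

        // Out-of-Long-range values throw IllegalArgumentException, e.g.
        // Decimal(new java.math.BigInteger("99999999999999999999"))
      }
    }
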
http://git-wip-us.apache.org/repos/asf/spark/blob/17591d90/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index 9c1319c..6b7e371 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -117,6 +117,7 @@ object DecimalType extends AbstractDataType {
   private[sql] val LongDecimal = DecimalType(20, 0)
   private[sql] val FloatDecimal = DecimalType(14, 7)
   private[sql] val DoubleDecimal = DecimalType(30, 15)
+  private[sql] val BigIntDecimal = DecimalType(38, 0)
 
   private[sql] def forType(dataType: DataType): DecimalType = dataType match {
     case ByteType => ByteDecimal

http://git-wip-us.apache.org/repos/asf/spark/blob/17591d90/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 227e835..d438789 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.encoders
 
+import java.math.BigInteger
 import java.sql.{Date, Timestamp}
 import java.util.Arrays
 
@@ -109,7 +110,8 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
 
   encodeDecodeTest(BigDecimal("32131413.211321313"), "scala decimal")
   encodeDecodeTest(new java.math.BigDecimal("231341.23123"), "java decimal")
-
+  encodeDecodeTest(BigInt("23134123123"), "scala biginteger")
+  encodeDecodeTest(new BigInteger("23134123123"), "java BigInteger")
   encodeDecodeTest(Decimal("32131413.211321313"), "catalyst decimal")
 
   encodeDecodeTest("hello", "string")

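(Illustration, not part of the commit: the round trip these tests exercise, written
against the 2.0-era encoder API, which exposed toRow/fromRow.)

    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

    object EncoderRoundTrip {
      def main(args: Array[String]): Unit = {
        val encoder = ExpressionEncoder[scala.math.BigInt]()
        val row = encoder.toRow(BigInt("23134123123"))    // encode to an InternalRow
        val back = encoder.resolveAndBind().fromRow(row)  // decode back to BigInt
        assert(back == BigInt("23134123123"))
      }
    }
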
http://git-wip-us.apache.org/repos/asf/spark/blob/17591d90/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 324ebba..35a9f44 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -21,6 +21,8 @@ import java.io.Serializable;
 import java.net.URISyntaxException;
 import java.net.URL;
 import java.util.*;
+import java.math.BigInteger;
+import java.math.BigDecimal;
 
 import scala.collection.JavaConverters;
 import scala.collection.Seq;
@@ -130,6 +132,7 @@ public class JavaDataFrameSuite {
     private Integer[] b = { 0, 1 };
     private Map<String, int[]> c = ImmutableMap.of("hello", new int[] { 1, 2 });
     private List<String> d = Arrays.asList("floppy", "disk");
+    private BigInteger e = new BigInteger("1234567");
 
     public double getA() {
       return a;
@@ -146,6 +149,8 @@ public class JavaDataFrameSuite {
     public List<String> getD() {
       return d;
     }
+
+    public BigInteger getE() { return e; }
   }
 
   void validateDataFrameWithBeans(Bean bean, Dataset<Row> df) {
@@ -163,7 +168,9 @@ public class JavaDataFrameSuite {
     Assert.assertEquals(
       new StructField("d", new ArrayType(DataTypes.StringType, true), true, Metadata.empty()),
       schema.apply("d"));
-    Row first = df.select("a", "b", "c", "d").first();
+    Assert.assertEquals(new StructField("e", DataTypes.createDecimalType(38,0), true, Metadata.empty()),
+      schema.apply("e"));
+    Row first = df.select("a", "b", "c", "d", "e").first();
     Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0);
    // Now Java lists and maps are converted to Scala Seq's and Map's. Once we get a Seq below,
     // verify that it has the expected length, and contains expected elements.
@@ -182,6 +189,8 @@ public class JavaDataFrameSuite {
     for (int i = 0; i < d.length(); i++) {
       Assert.assertEquals(bean.getD().get(i), d.apply(i));
     }
+    // java.math.BigInteger is equivalent to Spark Decimal(38,0)
+    Assert.assertEquals(new BigDecimal(bean.getE()), first.getDecimal(4));
   }
 
   @Test

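(Illustration, not part of the commit: the same bean path the Java suite covers, driven
from Scala via SparkSession.createDataFrame(list, beanClass); BigIntBean is an
illustrative name.)

    import java.util.Collections
    import org.apache.spark.sql.SparkSession

    class BigIntBean extends Serializable {
      private var e: java.math.BigInteger = java.math.BigInteger.valueOf(1234567L)
      def getE: java.math.BigInteger = e
      def setE(v: java.math.BigInteger): Unit = { e = v }
    }

    object BeanExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("bean").getOrCreate()
        val df = spark.createDataFrame(
          Collections.singletonList(new BigIntBean), classOf[BigIntBean])
        df.printSchema()  // e: decimal(38,0) (nullable = true)
        spark.stop()
      }
    }
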
http://git-wip-us.apache.org/repos/asf/spark/blob/17591d90/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
index 491bdb3..c9bd05d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
@@ -34,7 +34,9 @@ case class ReflectData(
     decimalField: java.math.BigDecimal,
     date: Date,
     timestampField: Timestamp,
-    seqInt: Seq[Int])
+    seqInt: Seq[Int],
+    javaBigInt: java.math.BigInteger,
+    scalaBigInt: scala.math.BigInt)
 
 case class NullReflectData(
     intField: java.lang.Integer,
@@ -77,13 +79,15 @@ class ScalaReflectionRelationSuite extends SparkFunSuite with SharedSQLContext {
 
   test("query case class RDD") {
     val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
-      new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3))
+      new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3),
+      new java.math.BigInteger("1"), scala.math.BigInt(1))
     Seq(data).toDF().createOrReplaceTempView("reflectData")
 
     assert(sql("SELECT * FROM reflectData").collect().head ===
       Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
         new java.math.BigDecimal(1), Date.valueOf("1970-01-01"),
-        new Timestamp(12345), Seq(1, 2, 3)))
+        new Timestamp(12345), Seq(1, 2, 3), new java.math.BigDecimal(1),
+        new java.math.BigDecimal(1)))
   }
 
   test("query case class RDD with nulls") {



