spark-commits mailing list archives

From marmb...@apache.org
Subject spark git commit: [SPARK-12404][SQL] Ensure objects passed to StaticInvoke are Serializable
Date Fri, 18 Dec 2015 22:05:19 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.6 3b903e44b -> bd33d4ee8


[SPARK-12404][SQL] Ensure objects passed to StaticInvoke are Serializable

Currently, `StaticInvoke` receives `Any` as its target object. `StaticInvoke` itself can be
serialized, but the object passed in sometimes is not.

For example, the following code raises an exception, because `RowEncoder#extractorsFor`, which is
invoked indirectly, creates a `StaticInvoke`.

```scala
case class TimestampContainer(timestamp: java.sql.Timestamp)
val rdd = sc.parallelize(1 to 2)
  .map(_ => TimestampContainer(new java.sql.Timestamp(System.currentTimeMillis)))
val df = rdd.toDF
val ds = df.as[TimestampContainer]
val rdd2 = ds.rdd   // <----------------- invokes extractorsFor indirectly
```
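
The failure mode is reproducible outside Spark: a plain Scala `object` does not extend
`Serializable` (in compiled code, outside the REPL), whereas `java.lang.Class` does, which is why
switching `StaticInvoke` from holding the singleton to holding its `Class[_]` fixes the problem.
Below is a minimal sketch of that difference; `MyHelper` and `Holder` are hypothetical stand-ins
for helpers like `DateTimeUtils` and for `StaticInvoke` respectively.

```scala
import java.io.{ByteArrayOutputStream, ObjectOutputStream}

// Hypothetical stand-in for a catalyst helper such as DateTimeUtils:
// a plain Scala object, which does not extend Serializable.
object MyHelper {
  def toJavaDate(days: Int): java.sql.Date = new java.sql.Date(days * 86400000L)
}

// Hypothetical stand-in for StaticInvoke: a serializable case class
// holding an arbitrary reference.
case class Holder(staticObject: Any)

def roundTrip(o: AnyRef): Unit =
  new ObjectOutputStream(new ByteArrayOutputStream()).writeObject(o)

roundTrip(Holder(MyHelper.getClass)) // fine: java.lang.Class is Serializable
roundTrip(Holder(MyHelper))          // throws java.io.NotSerializableException: MyHelper$
```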

I'll add test cases.

Author: Kousuke Saruta <sarutak@oss.nttdata.co.jp>
Author: Michael Armbrust <michael@databricks.com>

Closes #10357 from sarutak/SPARK-12404.

(cherry picked from commit 6eba655259d2bcea27d0147b37d5d1e476e85422)
Signed-off-by: Michael Armbrust <michael@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bd33d4ee
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bd33d4ee
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bd33d4ee

Branch: refs/heads/branch-1.6
Commit: bd33d4ee847973289a58032df35375f03e9f9865
Parents: 3b903e4
Author: Kousuke Saruta <sarutak@oss.nttdata.co.jp>
Authored: Fri Dec 18 14:05:06 2015 -0800
Committer: Michael Armbrust <michael@databricks.com>
Committed: Fri Dec 18 14:05:16 2015 -0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/JavaTypeInference.scala  | 12 ++---
 .../spark/sql/catalyst/ScalaReflection.scala    | 16 +++---
 .../sql/catalyst/encoders/RowEncoder.scala      | 14 +++---
 .../sql/catalyst/expressions/objects.scala      |  8 ++-
 .../org/apache/spark/sql/JavaDatasetSuite.java  | 52 ++++++++++++++++++++
 .../org/apache/spark/sql/DatasetSuite.scala     | 12 +++++
 6 files changed, 88 insertions(+), 26 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/bd33d4ee/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index c8ee87e..f566d1b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -194,7 +194,7 @@ object JavaTypeInference {
 
       case c if c == classOf[java.sql.Date] =>
         StaticInvoke(
-          DateTimeUtils,
+          DateTimeUtils.getClass,
           ObjectType(c),
           "toJavaDate",
           getPath :: Nil,
@@ -202,7 +202,7 @@ object JavaTypeInference {
 
       case c if c == classOf[java.sql.Timestamp] =>
         StaticInvoke(
-          DateTimeUtils,
+          DateTimeUtils.getClass,
           ObjectType(c),
           "toJavaTimestamp",
           getPath :: Nil,
@@ -276,7 +276,7 @@ object JavaTypeInference {
             ObjectType(classOf[Array[Any]]))
 
         StaticInvoke(
-          ArrayBasedMapData,
+          ArrayBasedMapData.getClass,
           ObjectType(classOf[JMap[_, _]]),
           "toJavaMap",
           keyData :: valueData :: Nil)
@@ -341,21 +341,21 @@ object JavaTypeInference {
 
         case c if c == classOf[java.sql.Timestamp] =>
           StaticInvoke(
-            DateTimeUtils,
+            DateTimeUtils.getClass,
             TimestampType,
             "fromJavaTimestamp",
             inputObject :: Nil)
 
         case c if c == classOf[java.sql.Date] =>
           StaticInvoke(
-            DateTimeUtils,
+            DateTimeUtils.getClass,
             DateType,
             "fromJavaDate",
             inputObject :: Nil)
 
         case c if c == classOf[java.math.BigDecimal] =>
           StaticInvoke(
-            Decimal,
+            Decimal.getClass,
             DecimalType.SYSTEM_DEFAULT,
             "apply",
             inputObject :: Nil)

http://git-wip-us.apache.org/repos/asf/spark/blob/bd33d4ee/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 9b6b5b8..ea98956 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -223,7 +223,7 @@ object ScalaReflection extends ScalaReflection {
 
       case t if t <:< localTypeOf[java.sql.Date] =>
         StaticInvoke(
-          DateTimeUtils,
+          DateTimeUtils.getClass,
           ObjectType(classOf[java.sql.Date]),
           "toJavaDate",
           getPath :: Nil,
@@ -231,7 +231,7 @@ object ScalaReflection extends ScalaReflection {
 
       case t if t <:< localTypeOf[java.sql.Timestamp] =>
         StaticInvoke(
-          DateTimeUtils,
+          DateTimeUtils.getClass,
           ObjectType(classOf[java.sql.Timestamp]),
           "toJavaTimestamp",
           getPath :: Nil,
@@ -287,7 +287,7 @@ object ScalaReflection extends ScalaReflection {
             ObjectType(classOf[Array[Any]]))
 
         StaticInvoke(
-          scala.collection.mutable.WrappedArray,
+          scala.collection.mutable.WrappedArray.getClass,
           ObjectType(classOf[Seq[_]]),
           "make",
           arrayData :: Nil)
@@ -315,7 +315,7 @@ object ScalaReflection extends ScalaReflection {
             ObjectType(classOf[Array[Any]]))
 
         StaticInvoke(
-          ArrayBasedMapData,
+          ArrayBasedMapData.getClass,
           ObjectType(classOf[Map[_, _]]),
           "toScalaMap",
           keyData :: valueData :: Nil)
@@ -552,28 +552,28 @@ object ScalaReflection extends ScalaReflection {
 
         case t if t <:< localTypeOf[java.sql.Timestamp] =>
           StaticInvoke(
-            DateTimeUtils,
+            DateTimeUtils.getClass,
             TimestampType,
             "fromJavaTimestamp",
             inputObject :: Nil)
 
         case t if t <:< localTypeOf[java.sql.Date] =>
           StaticInvoke(
-            DateTimeUtils,
+            DateTimeUtils.getClass,
             DateType,
             "fromJavaDate",
             inputObject :: Nil)
 
         case t if t <:< localTypeOf[BigDecimal] =>
           StaticInvoke(
-            Decimal,
+            Decimal.getClass,
             DecimalType.SYSTEM_DEFAULT,
             "apply",
             inputObject :: Nil)
 
         case t if t <:< localTypeOf[java.math.BigDecimal] =>
           StaticInvoke(
-            Decimal,
+            Decimal.getClass,
             DecimalType.SYSTEM_DEFAULT,
             "apply",
             inputObject :: Nil)

http://git-wip-us.apache.org/repos/asf/spark/blob/bd33d4ee/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 67518f5..6c1220a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -61,21 +61,21 @@ object RowEncoder {
 
     case TimestampType =>
       StaticInvoke(
-        DateTimeUtils,
+        DateTimeUtils.getClass,
         TimestampType,
         "fromJavaTimestamp",
         inputObject :: Nil)
 
     case DateType =>
       StaticInvoke(
-        DateTimeUtils,
+        DateTimeUtils.getClass,
         DateType,
         "fromJavaDate",
         inputObject :: Nil)
 
     case _: DecimalType =>
       StaticInvoke(
-        Decimal,
+        Decimal.getClass,
         DecimalType.SYSTEM_DEFAULT,
         "apply",
         inputObject :: Nil)
@@ -172,14 +172,14 @@ object RowEncoder {
 
     case TimestampType =>
       StaticInvoke(
-        DateTimeUtils,
+        DateTimeUtils.getClass,
         ObjectType(classOf[java.sql.Timestamp]),
         "toJavaTimestamp",
         input :: Nil)
 
     case DateType =>
       StaticInvoke(
-        DateTimeUtils,
+        DateTimeUtils.getClass,
         ObjectType(classOf[java.sql.Date]),
         "toJavaDate",
         input :: Nil)
@@ -197,7 +197,7 @@ object RowEncoder {
           "array",
           ObjectType(classOf[Array[_]]))
       StaticInvoke(
-        scala.collection.mutable.WrappedArray,
+        scala.collection.mutable.WrappedArray.getClass,
         ObjectType(classOf[Seq[_]]),
         "make",
         arrayData :: Nil)
@@ -210,7 +210,7 @@ object RowEncoder {
       val valueData = constructorFor(Invoke(input, "valueArray", valueArrayType))
 
       StaticInvoke(
-        ArrayBasedMapData,
+        ArrayBasedMapData.getClass,
         ObjectType(classOf[Map[_, _]]),
         "toScalaMap",
         keyData :: valueData :: Nil)

http://git-wip-us.apache.org/repos/asf/spark/blob/bd33d4ee/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index e6ab9a3..749936c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -44,16 +44,14 @@ import org.apache.spark.sql.types._
  *                      of calling the function.
  */
 case class StaticInvoke(
-    staticObject: Any,
+    staticObject: Class[_],
     dataType: DataType,
     functionName: String,
     arguments: Seq[Expression] = Nil,
     propagateNull: Boolean = true) extends Expression {
 
-  val objectName = staticObject match {
-    case c: Class[_] => c.getName
-    case other => other.getClass.getName.stripSuffix("$")
-  }
+  val objectName = staticObject.getName.stripSuffix("$")
+
   override def nullable: Boolean = true
   override def children: Seq[Expression] = arguments
 

http://git-wip-us.apache.org/repos/asf/spark/blob/bd33d4ee/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 383a2d0..0dbaeb8 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -39,6 +39,7 @@ import org.apache.spark.sql.expressions.Aggregator;
 import org.apache.spark.sql.test.TestSQLContext;
 import org.apache.spark.sql.catalyst.encoders.OuterScopes;
 import org.apache.spark.sql.catalyst.expressions.GenericRow;
+import org.apache.spark.sql.types.DecimalType;
 import org.apache.spark.sql.types.StructType;
 
 import static org.apache.spark.sql.functions.*;
@@ -608,6 +609,44 @@ public class JavaDatasetSuite implements Serializable {
     }
   }
 
+  public class SimpleJavaBean2 implements Serializable {
+    private Timestamp a;
+    private Date b;
+    private java.math.BigDecimal c;
+
+    public Timestamp getA() { return a; }
+
+    public void setA(Timestamp a) { this.a = a; }
+
+    public Date getB() { return b; }
+
+    public void setB(Date b) { this.b = b; }
+
+    public java.math.BigDecimal getC() { return c; }
+
+    public void setC(java.math.BigDecimal c) { this.c = c; }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) return true;
+      if (o == null || getClass() != o.getClass()) return false;
+
+      SimpleJavaBean2 that = (SimpleJavaBean2) o;
+
+      if (!a.equals(that.a)) return false;
+      if (!b.equals(that.b)) return false;
+      return c.equals(that.c);
+    }
+
+    @Override
+    public int hashCode() {
+      int result = a.hashCode();
+      result = 31 * result + b.hashCode();
+      result = 31 * result + c.hashCode();
+      return result;
+    }
+  }
+
   public class NestedJavaBean implements Serializable {
     private SimpleJavaBean a;
 
@@ -689,4 +728,17 @@ public class JavaDatasetSuite implements Serializable {
       .as(Encoders.bean(SimpleJavaBean.class));
     Assert.assertEquals(data, ds3.collectAsList());
   }
+
+  @Test
+  public void testJavaBeanEncoder2() {
+    // This is a regression test of SPARK-12404
+    OuterScopes.addOuterScope(this);
+    SimpleJavaBean2 obj = new SimpleJavaBean2();
+    obj.setA(new Timestamp(0));
+    obj.setB(new Date(0));
+    obj.setC(java.math.BigDecimal.valueOf(1));
+    Dataset<SimpleJavaBean2> ds =
+      context.createDataset(Arrays.asList(obj), Encoders.bean(SimpleJavaBean2.class));
+    ds.collect();
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/bd33d4ee/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 542e4d6..c6b3991 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql
 
 import java.io.{ObjectInput, ObjectOutput, Externalizable}
+import java.sql.{Date, Timestamp}
 
 import scala.language.postfixOps
 
@@ -42,6 +43,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       1, 1, 1)
   }
 
+
+  test("SPARK-12404: Datatype Helper Serializablity") {
+    val ds = sparkContext.parallelize((
+          new Timestamp(0),
+          new Date(0),
+          java.math.BigDecimal.valueOf(1),
+          scala.math.BigDecimal(1)) :: Nil).toDS()
+
+    ds.collect()
+  }
+
   test("collect, first, and take should use encoders for serialization") {
     val item = NonSerializableCaseClass("abcd")
     val ds = Seq(item).toDS()
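
With the change applied, the repro from the description completes; a quick spark-shell check on
branch-1.6 might look like this (assuming `sqlContext.implicits._` is in scope, as it is by
default in the shell):

```scala
import sqlContext.implicits._

case class TimestampContainer(timestamp: java.sql.Timestamp)

val ds = sc.parallelize(1 to 2)
  .map(_ => TimestampContainer(new java.sql.Timestamp(System.currentTimeMillis)))
  .toDF()
  .as[TimestampContainer]

ds.rdd.count()  // previously threw NotSerializableException; now returns 2
```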

