spark-commits mailing list archives

From wenc...@apache.org
Subject spark git commit: [SPARK-16792][SQL] Dataset containing a Case Class with a List type causes a CompileException (converting sequence to list)
Date Fri, 06 Jan 2017 07:05:31 GMT
Repository: spark
Updated Branches:
  refs/heads/master bcc510b02 -> 903bb8e8a


[SPARK-16792][SQL] Dataset containing a Case Class with a List type causes a CompileException (converting sequence to list)

## What changes were proposed in this pull request?

Added a `to` call at the end of the code generated by `ScalaReflection.deserializerFor` when
the requested type is not a supertype of `WrappedArray[_]`. The call uses `CanBuildFrom[_, _, _]`
to convert the result into an arbitrary subtype of `Seq[_]`.
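The conversion step can be illustrated outside Spark with a minimal plain-Scala sketch. It is not the code Spark generates, and it assumes Scala 2.13 or later, where the `CanBuildFrom` used by this commit was replaced by `scala.collection.Factory`; `convertTo` is a hypothetical helper name.

```scala
import scala.collection.Factory
import scala.collection.immutable.Queue

// Hypothetical helper: rebuild a generic Seq as the requested collection type C.
// The Factory argument plays the role that CanBuildFrom played in Scala 2.11/2.12.
def convertTo[C](xs: Seq[Int], factory: Factory[Int, C]): C =
  xs.to(factory)

// Passing a collection companion object (List, Queue) selects the target type.
val asList: List[Int] = convertTo(Seq(1, 2, 3), List)
val asQueue: Queue[Int] = convertTo(Seq(1, 2, 3), Queue)
```

The caller only names the target collection; the factory knows how to build it, which is the same division of labour the generated `to` call relies on.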

Care was taken to preserve the original deserialization wherever possible, avoiding the overhead
of conversion when it is not needed (that is, when the requested type is already a supertype of
`WrappedArray[_]`).
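The skip-or-convert decision can be sketched as follows. This is a simplified illustration, not Spark's implementation: the hypothetical `deserialize` helper stands in for the generated expression, a runtime `Class.isInstance` check replaces the compile-time `<:<` check on the requested type, and a caller-supplied function replaces the `to` call.

```scala
import scala.collection.immutable.Queue

// Hypothetical stand-in for the generated deserializer. The elements arrive as
// an array and are first exposed as a generic Seq (Spark used WrappedArray.make).
// The conversion runs only when that Seq is not already an instance of the
// requested collection class, so supertypes such as Seq pay no extra cost.
def deserialize[C](arr: Array[Int], target: Class[C])(convert: Seq[Int] => C): C = {
  val generic: Seq[Int] = arr.toSeq
  if (target.isInstance(generic)) target.cast(generic) // reuse as-is
  else convert(generic)                                // e.g. List or Queue
}

// Requesting Seq reuses the generic value; the conversion must not run.
val asSeq = deserialize(Array(1, 2), classOf[Seq[_]])(_ => sys.error("unexpected conversion"))
// Requesting List or Queue goes through the conversion function.
val asList = deserialize(Array(1, 2), classOf[List[_]])(_.toList)
val asQueue = deserialize(Array(1, 2, 3), classOf[Queue[_]])(s => Queue.empty[Int] ++ s)
```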

`ScalaReflection.serializerFor` could already serialize any `Seq[_]`, so it was not altered.

`SQLImplicits` had to be altered, and new implicit encoders were added, to permit serialization
of other sequence types.
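Why `newProductEncoder` moves into a separate lower-priority trait (`LowPrioritySQLImplicits` in the diff below) can be sketched with a toy type class. `Enc`, `LowPriorityEncs`, `Encs` and `encoderName` are hypothetical names, not Spark APIs. Scala's overloading-resolution rules give extra weight to an implicit defined in a class or object that derives from the one defining a competing implicit, so the sequence instance wins when both apply. In the Scala 2.11/2.12 collections Spark used at the time, `List` extends both `Seq` and `Product`; the sketch uses the case class `::` (a non-empty `List`), which is both in any Scala version.

```scala
// Hypothetical stand-in for Spark's Encoder type class.
trait Enc[T] { def name: String }

// Lower-priority instances live in a parent trait (cf. LowPrioritySQLImplicits).
trait LowPriorityEncs {
  implicit def productEnc[T <: Product]: Enc[T] = new Enc[T] { def name = "product" }
}

// Instances defined in the derived object take precedence when both apply.
object Encs extends LowPriorityEncs {
  implicit def intSeqEnc[T <: Seq[Int]]: Enc[T] = new Enc[T] { def name = "seq" }
}

import Encs._

def encoderName[T](implicit enc: Enc[T]): String = enc.name

// `::` is a case class (a Product) and a List (a Seq[Int]), so both implicits
// are eligible; in Scala 2, defining them side by side would be ambiguous.
type NonEmptyIntList = scala.collection.immutable.::[Int]

val consName = encoderName[NonEmptyIntList] // intSeqEnc wins via the trait split
val tupleName = encoderName[(Int, String)]  // only productEnc applies
val vectorName = encoderName[Vector[Int]]   // only intSeqEnc applies
```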

Also fixes [SPARK-16815]: Dataset[List[T]] leads to ArrayStoreException.

## How was this patch tested?
```bash
./build/mvn -DskipTests clean package && ./dev/run-tests
```

Also tested by manually executing the following sets of commands in the Spark shell:
```scala
case class TestCC(key: Int, letters: List[String])

val ds1 = sc.makeRDD(Seq(
  List("D"),
  List("S", "H"),
  List("F", "H"),
  List("D", "L", "L")
)).map(x => (x.length, x)).toDF("key", "letters").as[TestCC]

val test1 = ds1.map(_.key)
test1.show
```

```scala
case class X(l: List[String])
spark.createDataset(Seq(List("A"))).map(X).show
```

```scala
spark.sqlContext.createDataset(sc.parallelize(List(1) :: Nil)).collect
```

After adding arbitrary sequence support, also tested with the following commands:

```scala
case class QueueClass(q: scala.collection.immutable.Queue[Int])

spark.createDataset(Seq(List(1, 2, 3))).map(x => QueueClass(scala.collection.immutable.Queue(x: _*))).map(_.q.dequeue).collect
```

Author: Michal Senkyr <mike.senkyr@gmail.com>

Closes #16240 from michalsenkyr/sql-caseclass-list-fix.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/903bb8e8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/903bb8e8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/903bb8e8

Branch: refs/heads/master
Commit: 903bb8e8a2b84b9ea82acbb8ae9d58754862be3a
Parents: bcc510b
Author: Michal Senkyr <mike.senkyr@gmail.com>
Authored: Fri Jan 6 15:05:20 2017 +0800
Committer: Wenchen Fan <wenchen@databricks.com>
Committed: Fri Jan 6 15:05:20 2017 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/ScalaReflection.scala    |  40 ++++++-
 .../sql/catalyst/ScalaReflectionSuite.scala     |  31 +++++
 .../org/apache/spark/sql/SQLImplicits.scala     | 115 +++++++++++++++----
 .../spark/sql/DatasetPrimitiveSuite.scala       |  67 +++++++++++
 4 files changed, 231 insertions(+), 22 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/903bb8e8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index ad218cf..7f7dd51 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -312,12 +312,50 @@ object ScalaReflection extends ScalaReflection {
           "array",
           ObjectType(classOf[Array[Any]]))
 
-        StaticInvoke(
+        val wrappedArray = StaticInvoke(
           scala.collection.mutable.WrappedArray.getClass,
           ObjectType(classOf[Seq[_]]),
           "make",
           array :: Nil)
 
+        if (localTypeOf[scala.collection.mutable.WrappedArray[_]] <:< t.erasure) {
+          wrappedArray
+        } else {
+          // Convert to another type using `to`
+          val cls = mirror.runtimeClass(t.typeSymbol.asClass)
+          import scala.collection.generic.CanBuildFrom
+          import scala.reflect.ClassTag
+
+          // Some canBuildFrom methods take an implicit ClassTag parameter
+          val cbfParams = try {
+            cls.getDeclaredMethod("canBuildFrom", classOf[ClassTag[_]])
+            StaticInvoke(
+              ClassTag.getClass,
+              ObjectType(classOf[ClassTag[_]]),
+              "apply",
+              StaticInvoke(
+                cls,
+                ObjectType(classOf[Class[_]]),
+                "getClass"
+              ) :: Nil
+            ) :: Nil
+          } catch {
+            case _: NoSuchMethodException => Nil
+          }
+
+          Invoke(
+            wrappedArray,
+            "to",
+            ObjectType(cls),
+            StaticInvoke(
+              cls,
+              ObjectType(classOf[CanBuildFrom[_, _, _]]),
+              "canBuildFrom",
+              cbfParams
+            ) :: Nil
+          )
+        }
+
       case t if t <:< localTypeOf[Map[_, _]] =>
         // TODO: add walked type path for map
         val TypeRef(_, _, Seq(keyType, valueType)) = t

http://git-wip-us.apache.org/repos/asf/spark/blob/903bb8e8/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 43b6afd..650a353 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -291,6 +291,37 @@ class ScalaReflectionSuite extends SparkFunSuite {
       .cls.isAssignableFrom(classOf[org.apache.spark.sql.catalyst.util.GenericArrayData]))
   }
 
+  test("SPARK 16792: Get correct deserializer for List[_]") {
+    val listDeserializer = deserializerFor[List[Int]]
+    assert(listDeserializer.dataType == ObjectType(classOf[List[_]]))
+  }
+
+  test("serialize and deserialize arbitrary sequence types") {
+    import scala.collection.immutable.Queue
+    val queueSerializer = serializerFor[Queue[Int]](BoundReference(
+      0, ObjectType(classOf[Queue[Int]]), nullable = false))
+    assert(queueSerializer.dataType.head.dataType ==
+      ArrayType(IntegerType, containsNull = false))
+    val queueDeserializer = deserializerFor[Queue[Int]]
+    assert(queueDeserializer.dataType == ObjectType(classOf[Queue[_]]))
+
+    import scala.collection.mutable.ArrayBuffer
+    val arrayBufferSerializer = serializerFor[ArrayBuffer[Int]](BoundReference(
+      0, ObjectType(classOf[ArrayBuffer[Int]]), nullable = false))
+    assert(arrayBufferSerializer.dataType.head.dataType ==
+      ArrayType(IntegerType, containsNull = false))
+    val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]]
+    assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]]))
+
+    // Check whether conversion is skipped when using WrappedArray[_] supertype
+    // (would otherwise needlessly add overhead)
+    import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
+    val seqDeserializer = deserializerFor[Seq[Int]]
+    assert(seqDeserializer.asInstanceOf[StaticInvoke].staticObject ==
+      scala.collection.mutable.WrappedArray.getClass)
+    assert(seqDeserializer.asInstanceOf[StaticInvoke].functionName == "make")
+  }
+
   private val dataTypeForComplexData = dataTypeFor[ComplexData]
   private val typeOfComplexData = typeOf[ComplexData]
 

http://git-wip-us.apache.org/repos/asf/spark/blob/903bb8e8/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index 872a78b..2caf723 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
  * @since 1.6.0
  */
 @InterfaceStability.Evolving
-abstract class SQLImplicits {
+abstract class SQLImplicits extends LowPrioritySQLImplicits {
 
   protected def _sqlContext: SQLContext
 
@@ -45,9 +45,6 @@ abstract class SQLImplicits {
     }
   }
 
-  /** @since 1.6.0 */
-  implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = Encoders.product[T]
-
   // Primitives
 
   /** @since 1.6.0 */
@@ -112,33 +109,96 @@ abstract class SQLImplicits {
 
   // Seqs
 
-  /** @since 1.6.1 */
-  implicit def newIntSeqEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
+  /**
+   * @since 1.6.1
+   * @deprecated use [[newIntSequenceEncoder]]
+   */
+  def newIntSeqEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
 
-  /** @since 1.6.1 */
-  implicit def newLongSeqEncoder: Encoder[Seq[Long]] = ExpressionEncoder()
+  /**
+   * @since 1.6.1
+   * @deprecated use [[newLongSequenceEncoder]]
+   */
+  def newLongSeqEncoder: Encoder[Seq[Long]] = ExpressionEncoder()
 
-  /** @since 1.6.1 */
-  implicit def newDoubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder()
+  /**
+   * @since 1.6.1
+   * @deprecated use [[newDoubleSequenceEncoder]]
+   */
+  def newDoubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder()
 
-  /** @since 1.6.1 */
-  implicit def newFloatSeqEncoder: Encoder[Seq[Float]] = ExpressionEncoder()
+  /**
+   * @since 1.6.1
+   * @deprecated use [[newFloatSequenceEncoder]]
+   */
+  def newFloatSeqEncoder: Encoder[Seq[Float]] = ExpressionEncoder()
 
-  /** @since 1.6.1 */
-  implicit def newByteSeqEncoder: Encoder[Seq[Byte]] = ExpressionEncoder()
+  /**
+   * @since 1.6.1
+   * @deprecated use [[newByteSequenceEncoder]]
+   */
+  def newByteSeqEncoder: Encoder[Seq[Byte]] = ExpressionEncoder()
 
-  /** @since 1.6.1 */
-  implicit def newShortSeqEncoder: Encoder[Seq[Short]] = ExpressionEncoder()
+  /**
+   * @since 1.6.1
+   * @deprecated use [[newShortSequenceEncoder]]
+   */
+  def newShortSeqEncoder: Encoder[Seq[Short]] = ExpressionEncoder()
 
-  /** @since 1.6.1 */
-  implicit def newBooleanSeqEncoder: Encoder[Seq[Boolean]] = ExpressionEncoder()
+  /**
+   * @since 1.6.1
+   * @deprecated use [[newBooleanSequenceEncoder]]
+   */
+  def newBooleanSeqEncoder: Encoder[Seq[Boolean]] = ExpressionEncoder()
 
-  /** @since 1.6.1 */
-  implicit def newStringSeqEncoder: Encoder[Seq[String]] = ExpressionEncoder()
+  /**
+   * @since 1.6.1
+   * @deprecated use [[newStringSequenceEncoder]]
+   */
+  def newStringSeqEncoder: Encoder[Seq[String]] = ExpressionEncoder()
 
-  /** @since 1.6.1 */
+  /**
+   * @since 1.6.1
+   * @deprecated use [[newProductSequenceEncoder]]
+   */
   implicit def newProductSeqEncoder[A <: Product : TypeTag]: Encoder[Seq[A]] = ExpressionEncoder()
 
+  /** @since 2.2.0 */
+  implicit def newIntSequenceEncoder[T <: Seq[Int] : TypeTag]: Encoder[T] =
+    ExpressionEncoder()
+
+  /** @since 2.2.0 */
+  implicit def newLongSequenceEncoder[T <: Seq[Long] : TypeTag]: Encoder[T] =
+    ExpressionEncoder()
+
+  /** @since 2.2.0 */
+  implicit def newDoubleSequenceEncoder[T <: Seq[Double] : TypeTag]: Encoder[T] =
+    ExpressionEncoder()
+
+  /** @since 2.2.0 */
+  implicit def newFloatSequenceEncoder[T <: Seq[Float] : TypeTag]: Encoder[T] =
+    ExpressionEncoder()
+
+  /** @since 2.2.0 */
+  implicit def newByteSequenceEncoder[T <: Seq[Byte] : TypeTag]: Encoder[T] =
+    ExpressionEncoder()
+
+  /** @since 2.2.0 */
+  implicit def newShortSequenceEncoder[T <: Seq[Short] : TypeTag]: Encoder[T] =
+    ExpressionEncoder()
+
+  /** @since 2.2.0 */
+  implicit def newBooleanSequenceEncoder[T <: Seq[Boolean] : TypeTag]: Encoder[T] =
+    ExpressionEncoder()
+
+  /** @since 2.2.0 */
+  implicit def newStringSequenceEncoder[T <: Seq[String] : TypeTag]: Encoder[T] =
+    ExpressionEncoder()
+
+  /** @since 2.2.0 */
+  implicit def newProductSequenceEncoder[T <: Seq[Product] : TypeTag]: Encoder[T] =
+    ExpressionEncoder()
+
   // Arrays
 
   /** @since 1.6.1 */
@@ -193,3 +253,16 @@ abstract class SQLImplicits {
   implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)
 
 }
+
+/**
+ * Lower priority implicit methods for converting Scala objects into [[Dataset]]s.
+ * Conflicting implicits are placed here to disambiguate resolution.
+ *
+ * Reasons for including specific implicits:
+ * newProductEncoder - to disambiguate for [[List]]s which are both [[Seq]] and [[Product]]
+ */
+trait LowPrioritySQLImplicits {
+  /** @since 1.6.0 */
+  implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = Encoders.product[T]
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/903bb8e8/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index f8d4c61..6b50cb3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -17,10 +17,21 @@
 
 package org.apache.spark.sql
 
+import scala.collection.immutable.Queue
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark.sql.test.SharedSQLContext
 
 case class IntClass(value: Int)
 
+case class SeqClass(s: Seq[Int])
+
+case class ListClass(l: List[Int])
+
+case class QueueClass(q: Queue[Int])
+
+case class ComplexClass(seq: SeqClass, list: ListClass, queue: QueueClass)
+
 package object packageobject {
   case class PackageClass(value: Int)
 }
@@ -130,6 +141,62 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
     checkDataset(Seq(Array(Tuple1(1))).toDS(), Array(Tuple1(1)))
   }
 
+  test("arbitrary sequences") {
+    checkDataset(Seq(Queue(1)).toDS(), Queue(1))
+    checkDataset(Seq(Queue(1.toLong)).toDS(), Queue(1.toLong))
+    checkDataset(Seq(Queue(1.toDouble)).toDS(), Queue(1.toDouble))
+    checkDataset(Seq(Queue(1.toFloat)).toDS(), Queue(1.toFloat))
+    checkDataset(Seq(Queue(1.toByte)).toDS(), Queue(1.toByte))
+    checkDataset(Seq(Queue(1.toShort)).toDS(), Queue(1.toShort))
+    checkDataset(Seq(Queue(true)).toDS(), Queue(true))
+    checkDataset(Seq(Queue("test")).toDS(), Queue("test"))
+    checkDataset(Seq(Queue(Tuple1(1))).toDS(), Queue(Tuple1(1)))
+
+    checkDataset(Seq(ArrayBuffer(1)).toDS(), ArrayBuffer(1))
+    checkDataset(Seq(ArrayBuffer(1.toLong)).toDS(), ArrayBuffer(1.toLong))
+    checkDataset(Seq(ArrayBuffer(1.toDouble)).toDS(), ArrayBuffer(1.toDouble))
+    checkDataset(Seq(ArrayBuffer(1.toFloat)).toDS(), ArrayBuffer(1.toFloat))
+    checkDataset(Seq(ArrayBuffer(1.toByte)).toDS(), ArrayBuffer(1.toByte))
+    checkDataset(Seq(ArrayBuffer(1.toShort)).toDS(), ArrayBuffer(1.toShort))
+    checkDataset(Seq(ArrayBuffer(true)).toDS(), ArrayBuffer(true))
+    checkDataset(Seq(ArrayBuffer("test")).toDS(), ArrayBuffer("test"))
+    checkDataset(Seq(ArrayBuffer(Tuple1(1))).toDS(), ArrayBuffer(Tuple1(1)))
+  }
+
+  test("sequence and product combinations") {
+    // Case classes
+    checkDataset(Seq(SeqClass(Seq(1))).toDS(), SeqClass(Seq(1)))
+    checkDataset(Seq(Seq(SeqClass(Seq(1)))).toDS(), Seq(SeqClass(Seq(1))))
+    checkDataset(Seq(List(SeqClass(Seq(1)))).toDS(), List(SeqClass(Seq(1))))
+    checkDataset(Seq(Queue(SeqClass(Seq(1)))).toDS(), Queue(SeqClass(Seq(1))))
+
+    checkDataset(Seq(ListClass(List(1))).toDS(), ListClass(List(1)))
+    checkDataset(Seq(Seq(ListClass(List(1)))).toDS(), Seq(ListClass(List(1))))
+    checkDataset(Seq(List(ListClass(List(1)))).toDS(), List(ListClass(List(1))))
+    checkDataset(Seq(Queue(ListClass(List(1)))).toDS(), Queue(ListClass(List(1))))
+
+    checkDataset(Seq(QueueClass(Queue(1))).toDS(), QueueClass(Queue(1)))
+    checkDataset(Seq(Seq(QueueClass(Queue(1)))).toDS(), Seq(QueueClass(Queue(1))))
+    checkDataset(Seq(List(QueueClass(Queue(1)))).toDS(), List(QueueClass(Queue(1))))
+    checkDataset(Seq(Queue(QueueClass(Queue(1)))).toDS(), Queue(QueueClass(Queue(1))))
+
+    val complex = ComplexClass(SeqClass(Seq(1)), ListClass(List(2)), QueueClass(Queue(3)))
+    checkDataset(Seq(complex).toDS(), complex)
+    checkDataset(Seq(Seq(complex)).toDS(), Seq(complex))
+    checkDataset(Seq(List(complex)).toDS(), List(complex))
+    checkDataset(Seq(Queue(complex)).toDS(), Queue(complex))
+
+    // Tuples
+    checkDataset(Seq(Seq(1) -> Seq(2)).toDS(), Seq(1) -> Seq(2))
+    checkDataset(Seq(List(1) -> Queue(2)).toDS(), List(1) -> Queue(2))
+    checkDataset(Seq(List(Seq("test1") -> List(Queue("test2")))).toDS(),
+      List(Seq("test1") -> List(Queue("test2"))))
+
+    // Complex
+    checkDataset(Seq(ListClass(List(1)) -> Queue("test" -> SeqClass(Seq(2)))).toDS(),
+      ListClass(List(1)) -> Queue("test" -> SeqClass(Seq(2))))
+  }
+
   test("package objects") {
     import packageobject._
     checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1))



