spark-commits mailing list archives

From r...@apache.org
Subject spark git commit: [SPARK-11797][SQL] collect, first, and take should use encoders for serialization
Date Wed, 18 Nov 2015 05:41:01 GMT
Repository: spark
Updated Branches:
  refs/heads/master 98be8169f -> 91f4b6f2d


[SPARK-11797][SQL] collect, first, and take should use encoders for serialization

These methods previously went through Spark's default serializer (via the underlying RDD) instead of the Dataset's encoder.

Author: Reynold Xin <rxin@databricks.com>

Closes #9787 from rxin/SPARK-11797.
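
As a rough, user-level illustration of what this change affects (a minimal sketch; the Record case class, the sqlContext value, and the variable names are hypothetical and not part of the patch):

    import sqlContext.implicits._

    case class Record(id: Int, name: String)  // hypothetical class, for illustration only

    val ds = Seq(Record(1, "a"), Record(2, "b"), Record(3, "c")).toDS()

    // After this patch these three calls decode rows with the Dataset's encoder
    // instead of round-tripping objects through Spark's default serializer.
    val all: Array[Record]   = ds.collect()
    val head: Record         = ds.first()    // now defined as take(1).head
    val twoOf: Array[Record] = ds.take(2)    // plans a Limit before collecting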


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/91f4b6f2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/91f4b6f2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/91f4b6f2

Branch: refs/heads/master
Commit: 91f4b6f2db12650dfc33a576803ba8aeccf935dd
Parents: 98be816
Author: Reynold Xin <rxin@databricks.com>
Authored: Tue Nov 17 21:40:58 2015 -0800
Committer: Reynold Xin <rxin@databricks.com>
Committed: Tue Nov 17 21:40:58 2015 -0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/sql/Dataset.scala    | 17 +++++++----
 .../org/apache/spark/sql/DatasetSuite.scala     | 30 +++++++++++++++++++-
 2 files changed, 41 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/91f4b6f2/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index bd01dd4..718ed81 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -22,6 +22,7 @@ import scala.collection.JavaConverters._
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.rdd.RDD
 import org.apache.spark.api.java.function._
+import org.apache.spark.sql.catalyst.InternalRow
 
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
@@ -199,7 +200,6 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = {
-    encoderFor[T].assertUnresolved()
     new Dataset[U](
       sqlContext,
       MapPartitions[T, U](
@@ -519,7 +519,7 @@ class Dataset[T] private[sql](
    * Returns the first element in this [[Dataset]].
    * @since 1.6.0
    */
-  def first(): T = rdd.first()
+  def first(): T = take(1).head
 
   /**
    * Returns an array that contains all the elements in this [[Dataset]].
@@ -530,7 +530,14 @@ class Dataset[T] private[sql](
    * For Java API, use [[collectAsList]].
    * @since 1.6.0
    */
-  def collect(): Array[T] = rdd.collect()
+  def collect(): Array[T] = {
+    // This is different from Dataset.rdd in that it collects Rows, and then runs the encoders
+    // to convert the rows into objects of type T.
+    val tEnc = resolvedTEncoder
+    val input = queryExecution.analyzed.output
+    val bound = tEnc.bind(input)
+    queryExecution.toRdd.map(_.copy()).collect().map(bound.fromRow)
+  }
 
   /**
    * Returns an array that contains all the elements in this [[Dataset]].
@@ -541,7 +548,7 @@ class Dataset[T] private[sql](
    * For Java API, use [[collectAsList]].
    * @since 1.6.0
    */
-  def collectAsList(): java.util.List[T] = rdd.collect().toSeq.asJava
+  def collectAsList(): java.util.List[T] = collect().toSeq.asJava
 
   /**
    * Returns the first `num` elements of this [[Dataset]] as an array.
@@ -551,7 +558,7 @@ class Dataset[T] private[sql](
    *
    * @since 1.6.0
    */
-  def take(num: Int): Array[T] = rdd.take(num)
+  def take(num: Int): Array[T] = withPlan(Limit(Literal(num), _)).collect()
 
   /**
    * Returns the first `num` elements of this [[Dataset]] as an array.

http://git-wip-us.apache.org/repos/asf/spark/blob/91f4b6f2/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index a392234..ea29428 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql
 
+import java.io.{ObjectInput, ObjectOutput, Externalizable}
+
 import scala.language.postfixOps
 
 import org.apache.spark.sql.functions._
@@ -24,6 +26,20 @@ import org.apache.spark.sql.test.SharedSQLContext
 
 case class ClassData(a: String, b: Int)
 
+/**
+ * A class used to test serialization using encoders. This class throws exceptions when using
+ * Java serialization -- so the only way it can be "serialized" is through our encoders.
+ */
+case class NonSerializableCaseClass(value: String) extends Externalizable {
+  override def readExternal(in: ObjectInput): Unit = {
+    throw new UnsupportedOperationException
+  }
+
+  override def writeExternal(out: ObjectOutput): Unit = {
+    throw new UnsupportedOperationException
+  }
+}
+
 class DatasetSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
 
@@ -41,6 +57,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       1, 1, 1)
   }
 
+  test("collect, first, and take should use encoders for serialization") {
+    val item = NonSerializableCaseClass("abcd")
+    val ds = Seq(item).toDS()
+    assert(ds.collect().head == item)
+    assert(ds.collectAsList().get(0) == item)
+    assert(ds.first() == item)
+    assert(ds.take(1).head == item)
+    assert(ds.takeAsList(1).get(0) == item)
+  }
+
   test("as tuple") {
     val data = Seq(("a", 1), ("b", 2)).toDF("a", "b")
     checkAnswer(
@@ -75,6 +101,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
 
   ignore("Dataset should set the resolved encoders internally for maps") {
     // TODO: Enable this once we fix SPARK-11793.
+    // We inject a group by here to make sure this test case is future proof
+    // when we implement better pipelining and local execution mode.
     val ds: Dataset[(ClassData, Long)] = Seq(ClassData("one", 1), ClassData("two", 2)).toDS()
         .map(c => ClassData(c.a, c.b + 1))
         .groupBy(p => p).count()
@@ -219,7 +247,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       ("a", 30), ("b", 3), ("c", 1))
   }
 
-  test("groupBy function, fatMap") {
+  test("groupBy function, flatMap") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val grouped = ds.groupBy(v => (v._1, "word"))
     val agged = grouped.flatMap { case (g, iter) => Iterator(g._1, iter.map(_._2).sum.toString) }
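
The NonSerializableCaseClass helper in the test above is constructed so that plain Java serialization always fails, which means the new test can only pass if collect, first, and take go through the encoder path. A minimal standalone sketch of that failure mode (ordinary java.io serialization code, not part of the patch):

    import java.io.{ByteArrayOutputStream, ObjectOutputStream}

    // Externalizable instances are written via writeExternal, which this class
    // deliberately overrides to throw, so Java serialization cannot succeed.
    val out = new ObjectOutputStream(new ByteArrayOutputStream())
    out.writeObject(NonSerializableCaseClass("abcd"))  // throws UnsupportedOperationException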



