From: r...@apache.org
Subject: spark git commit: [SPARK-11802][SQL] Kryo-based encoder for opaque types in Datasets
Date: Wed, 18 Nov 2015 08:09:31 GMT
Repository: spark
Updated Branches:
  refs/heads/master 8019f66df -> 5e2b44474


[SPARK-11802][SQL] Kryo-based encoder for opaque types in Datasets

I also found a bug with self-joins returning incorrect results in the Dataset API. Two
(ignored) test cases are included here, and SPARK-11803 has been filed to track the fix.

Author: Reynold Xin <rxin@databricks.com>

Closes #9789 from rxin/SPARK-11802.
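
For reference, a minimal usage sketch of the new API (the class `Opaque` is hypothetical,
and `sqlContext` plus the implicits are assumed to be in scope as in DatasetSuite below):

    import org.apache.spark.sql.Encoders

    // An opaque type: not a case class, no built-in encoder. Encoders.kryo
    // falls back to Kryo and stores each object as a single binary column.
    class Opaque(val n: Int)

    implicit val enc = Encoders.kryo[Opaque]
    val ds = sqlContext.createDataset(Seq(new Opaque(1), new Opaque(2)))
    ds.count()  // 2, using the Dataset.count() added in this patch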


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5e2b4447
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5e2b4447
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5e2b4447

Branch: refs/heads/master
Commit: 5e2b44474c2b838bebeffe5ba5cd72961b0cd31e
Parents: 8019f66
Author: Reynold Xin <rxin@databricks.com>
Authored: Wed Nov 18 00:09:29 2015 -0800
Committer: Reynold Xin <rxin@databricks.com>
Committed: Wed Nov 18 00:09:29 2015 -0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/sql/Encoder.scala    | 31 ++++++++-
 .../catalyst/encoders/ExpressionEncoder.scala   |  4 +-
 .../sql/catalyst/encoders/ProductEncoder.scala  |  2 +-
 .../sql/catalyst/expressions/objects.scala      | 69 ++++++++++++++++++-
 .../catalyst/encoders/FlatEncoderSuite.scala    | 18 +++++
 .../scala/org/apache/spark/sql/Dataset.scala    |  6 ++
 .../org/apache/spark/sql/GroupedDataset.scala   |  1 -
 .../org/apache/spark/sql/DatasetSuite.scala     | 70 +++++++++++++++-----
 8 files changed, 178 insertions(+), 23 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5e2b4447/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
index c8b017e..79c2255 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -17,10 +17,11 @@
 
 package org.apache.spark.sql
 
-import scala.reflect.ClassTag
+import scala.reflect.{ClassTag, classTag}
 
 import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.catalyst.expressions.{DeserializeWithKryo, BoundReference, SerializeWithKryo}
+import org.apache.spark.sql.types._
 
 /**
  * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation.
@@ -37,7 +38,33 @@ trait Encoder[T] extends Serializable {
   def clsTag: ClassTag[T]
 }
 
+/**
+ * Methods for creating encoders.
+ */
 object Encoders {
+
+  /**
+   * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo.
+   * This encoder maps T into a single byte array (binary) field.
+   */
+  def kryo[T: ClassTag]: Encoder[T] = {
+    val ser = SerializeWithKryo(BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true))
+    val deser = DeserializeWithKryo[T](BoundReference(0, BinaryType, nullable = true), classTag[T])
+    ExpressionEncoder[T](
+      schema = new StructType().add("value", BinaryType),
+      flat = true,
+      toRowExpressions = Seq(ser),
+      fromRowExpression = deser,
+      clsTag = classTag[T]
+    )
+  }
+
+  /**
+   * Creates an encoder that serializes objects of type T using Kryo.
+   * This encoder maps T into a single byte array (binary) field.
+   */
+  def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz))
+
   def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true)
   def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true)
   def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true)
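
As a quick sketch of the two new entry points (the class `MyOpaque` here is hypothetical):
both overloads build the same encoder, whose row layout is the single binary column
"value" defined above.

    import org.apache.spark.sql.{Encoder, Encoders}
    import org.apache.spark.sql.types.{BinaryType, StructType}

    class MyOpaque(val x: Int)

    val scalaEnc: Encoder[MyOpaque] = Encoders.kryo[MyOpaque]          // via ClassTag
    val javaEnc: Encoder[MyOpaque] = Encoders.kryo(classOf[MyOpaque])  // Java-friendly
    assert(scalaEnc.schema == new StructType().add("value", BinaryType))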

http://git-wip-us.apache.org/repos/asf/spark/blob/5e2b4447/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 9a1a8f5..b977f27 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -161,7 +161,9 @@ case class ExpressionEncoder[T](
 
   @transient
   private lazy val extractProjection = GenerateUnsafeProjection.generate(toRowExpressions)
-  private val inputRow = new GenericMutableRow(1)
+
+  @transient
+  private lazy val inputRow = new GenericMutableRow(1)
 
   @transient
  private lazy val constructProjection = GenerateSafeProjection.generate(fromRowExpression :: Nil)
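
The change above makes `inputRow` a @transient lazy val like the projections around it;
presumably this keeps the mutable row out of Java serialization when an ExpressionEncoder
is shipped to executors, recreating it on first use in each JVM. A minimal sketch of the
pattern (the `Holder` class is illustrative, not from the patch):

    import java.io._

    class Holder extends Serializable {
      // Skipped during Java serialization; recomputed on first access
      // after deserialization.
      @transient private lazy val scratch = new StringBuilder("fresh")
      def value: String = scratch.toString
    }

    val bout = new ByteArrayOutputStream()
    new ObjectOutputStream(bout).writeObject(new Holder)
    val copy = new ObjectInputStream(new ByteArrayInputStream(bout.toByteArray))
      .readObject().asInstanceOf[Holder]
    assert(copy.value == "fresh")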

http://git-wip-us.apache.org/repos/asf/spark/blob/5e2b4447/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
index 414adb2..55c4ee1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
@@ -230,7 +230,7 @@ object ProductEncoder {
           Invoke(inputObject, "booleanValue", BooleanType)
 
         case other =>
-          throw new UnsupportedOperationException(s"Extractor for type $other is not supported")
+          throw new UnsupportedOperationException(s"Encoder for type $other is not supported")
       }
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/5e2b4447/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 5cd19de..489c612 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -17,13 +17,15 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import scala.language.existentials
+import scala.reflect.ClassTag
+
+import org.apache.spark.SparkConf
+import org.apache.spark.serializer.{KryoSerializerInstance, KryoSerializer}
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
 import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation}
 import org.apache.spark.sql.catalyst.util.GenericArrayData
-
-import scala.language.existentials
-
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
 import org.apache.spark.sql.types._
@@ -514,3 +516,64 @@ case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataTy
     """
   }
 }
+
+/** Serializes an input object using Kryo serializer. */
+case class SerializeWithKryo(child: Expression) extends UnaryExpression {
+
+  override def eval(input: InternalRow): Any =
+    throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+
+  override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val input = child.gen(ctx)
+    val kryo = ctx.freshName("kryoSerializer")
+    val kryoClass = classOf[KryoSerializer].getName
+    val kryoInstanceClass = classOf[KryoSerializerInstance].getName
+    val sparkConfClass = classOf[SparkConf].getName
+    ctx.addMutableState(
+      kryoInstanceClass,
+      kryo,
+      s"$kryo = ($kryoInstanceClass) new $kryoClass(new $sparkConfClass()).newInstance();")
+
+    s"""
+      ${input.code}
+      final boolean ${ev.isNull} = ${input.isNull};
+      ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+      if (!${ev.isNull}) {
+        ${ev.value} = $kryo.serialize(${input.value}, null).array();
+      }
+     """
+  }
+
+  override def dataType: DataType = BinaryType
+}
+
+/**
+ * Deserializes an input object using Kryo serializer. Note that the ClassTag is not an implicit
+ * parameter because TreeNode cannot copy implicit parameters.
+ */
+case class DeserializeWithKryo[T](child: Expression, tag: ClassTag[T]) extends UnaryExpression {
+
+  override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val input = child.gen(ctx)
+    val kryo = ctx.freshName("kryoSerializer")
+    val kryoClass = classOf[KryoSerializer].getName
+    val kryoInstanceClass = classOf[KryoSerializerInstance].getName
+    val sparkConfClass = classOf[SparkConf].getName
+    ctx.addMutableState(
+      kryoInstanceClass,
+      kryo,
+      s"$kryo = ($kryoInstanceClass) new $kryoClass(new $sparkConfClass()).newInstance();")
+
+    s"""
+      ${input.code}
+      final boolean ${ev.isNull} = ${input.isNull};
+      ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+      if (!${ev.isNull}) {
+        ${ev.value} = (${ctx.javaType(dataType)})
+          $kryo.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null);
+      }
+     """
+  }
+
+  override def dataType: DataType = ObjectType(tag.runtimeClass)
+}
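
At runtime the generated Java above amounts to a round trip through Spark's serializer
API; roughly the following Scala (not part of the patch):

    import java.nio.ByteBuffer
    import org.apache.spark.SparkConf
    import org.apache.spark.serializer.KryoSerializer

    val kryo = new KryoSerializer(new SparkConf()).newInstance()
    val bytes: Array[Byte] = kryo.serialize("hello").array()     // SerializeWithKryo
    val back = kryo.deserialize[String](ByteBuffer.wrap(bytes))  // DeserializeWithKryo
    assert(back == "hello")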

http://git-wip-us.apache.org/repos/asf/spark/blob/5e2b4447/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala
index 55821c4..2729db8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.encoders
 
 import java.sql.{Date, Timestamp}
+import org.apache.spark.sql.Encoders
 
 class FlatEncoderSuite extends ExpressionEncoderSuite {
   encodeDecodeTest(false, FlatEncoder[Boolean], "primitive boolean")
@@ -71,4 +72,21 @@ class FlatEncoderSuite extends ExpressionEncoderSuite {
  encodeDecodeTest(Map(1 -> "a", 2 -> null), FlatEncoder[Map[Int, String]], "map with null")
   encodeDecodeTest(Map(1 -> Map("a" -> 1), 2 -> Map("b" -> 2)),
     FlatEncoder[Map[Int, Map[String, Int]]], "map of map")
+
+  // Kryo encoders
+  encodeDecodeTest(
+    "hello",
+    Encoders.kryo[String].asInstanceOf[ExpressionEncoder[String]],
+    "kryo string")
+  encodeDecodeTest(
+    new NotJavaSerializable(15),
+    Encoders.kryo[NotJavaSerializable].asInstanceOf[ExpressionEncoder[NotJavaSerializable]],
+    "kryo object serialization")
+}
+
+
+class NotJavaSerializable(val value: Int) {
+  override def equals(other: Any): Boolean = {
+    this.value == other.asInstanceOf[NotJavaSerializable].value
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/5e2b4447/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 718ed81..817c20f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -147,6 +147,12 @@ class Dataset[T] private[sql](
     }
   }
 
+  /**
+   * Returns the number of elements in the [[Dataset]].
+   * @since 1.6.0
+   */
+  def count(): Long = toDF().count()
+
   /* *********************** *
    *  Functional Operations  *
    * *********************** */
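
The new count() simply delegates to the DataFrame implementation, so it executes as an
aggregate on the executors rather than collecting rows to the driver. Usage is the obvious
one (assuming the usual implicits are in scope):

    import sqlContext.implicits._
    val n: Long = Seq("a", "b", "c").toDS().count()  // 3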

http://git-wip-us.apache.org/repos/asf/spark/blob/5e2b4447/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index 467cd42..c66162e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql
 
-
 import scala.collection.JavaConverters._
 
 import org.apache.spark.annotation.Experimental

http://git-wip-us.apache.org/repos/asf/spark/blob/5e2b4447/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index ea29428..a522894 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -24,21 +24,6 @@ import scala.language.postfixOps
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
 
-case class ClassData(a: String, b: Int)
-
-/**
- * A class used to test serialization using encoders. This class throws exceptions when using
- * Java serialization -- so the only way it can be "serialized" is through our encoders.
- */
-case class NonSerializableCaseClass(value: String) extends Externalizable {
-  override def readExternal(in: ObjectInput): Unit = {
-    throw new UnsupportedOperationException
-  }
-
-  override def writeExternal(out: ObjectOutput): Unit = {
-    throw new UnsupportedOperationException
-  }
-}
 
 class DatasetSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
@@ -362,8 +347,63 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     checkAnswer(joined, ("2", 2))
   }
 
+  ignore("self join") {
+    val ds = Seq("1", "2").toDS().as("a")
+    val joined = ds.joinWith(ds, lit(true))
+    checkAnswer(joined, ("1", "1"), ("1", "2"), ("2", "1"), ("2", "2"))
+  }
+
   test("toString") {
     val ds = Seq((1, 2)).toDS()
     assert(ds.toString == "[_1: int, _2: int]")
   }
+
+  test("kryo encoder") {
+    implicit val kryoEncoder = Encoders.kryo[KryoData]
+    val ds = sqlContext.createDataset(Seq(KryoData(1), KryoData(2)))
+
+    assert(ds.groupBy(p => p).count().collect().toSeq ==
+      Seq((KryoData(1), 1L), (KryoData(2), 1L)))
+  }
+
+  ignore("kryo encoder self join") {
+    implicit val kryoEncoder = Encoders.kryo[KryoData]
+    val ds = sqlContext.createDataset(Seq(KryoData(1), KryoData(2)))
+    assert(ds.joinWith(ds, lit(true)).collect().toSet ==
+      Set(
+        (KryoData(1), KryoData(1)),
+        (KryoData(1), KryoData(2)),
+        (KryoData(2), KryoData(1)),
+        (KryoData(2), KryoData(2))))
+  }
+}
+
+
+case class ClassData(a: String, b: Int)
+
+/**
+ * A class used to test serialization using encoders. This class throws exceptions when using
+ * Java serialization -- so the only way it can be "serialized" is through our encoders.
+ */
+case class NonSerializableCaseClass(value: String) extends Externalizable {
+  override def readExternal(in: ObjectInput): Unit = {
+    throw new UnsupportedOperationException
+  }
+
+  override def writeExternal(out: ObjectOutput): Unit = {
+    throw new UnsupportedOperationException
+  }
+}
+
+/** Used to test Kryo encoder. */
+class KryoData(val a: Int) {
+  override def equals(other: Any): Boolean = {
+    a == other.asInstanceOf[KryoData].a
+  }
+  override def hashCode: Int = a
+  override def toString: String = s"KryoData($a)"
+}
+
+object KryoData {
+  def apply(a: Int): KryoData = new KryoData(a)
 }

