spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From r...@apache.org
Subject spark git commit: [SPARK-9674][SPARK-9667] Remove SparkSqlSerializer2
Date Fri, 07 Aug 2015 18:02:56 GMT
Repository: spark
Updated Branches:
  refs/heads/master ebfd91c54 -> 76eaa7018


[SPARK-9674][SPARK-9667] Remove SparkSqlSerializer2

It is now subsumed by various Tungsten operators.

Author: Reynold Xin <rxin@databricks.com>

Closes #7981 from rxin/SPARK-9674 and squashes the following commits:

144f96e [Reynold Xin] Re-enable test
58b7332 [Reynold Xin] Disable failing list.
fb797e3 [Reynold Xin] Match all UDTs.
be9f243 [Reynold Xin] Updated if.
71fc99c [Reynold Xin] [SPARK-9674][SPARK-9667] Remove GeneratedAggregate & SparkSqlSerializer2.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/76eaa701
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/76eaa701
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/76eaa701

Branch: refs/heads/master
Commit: 76eaa701833a2ff23b50147d70ced41e85719572
Parents: ebfd91c
Author: Reynold Xin <rxin@databricks.com>
Authored: Fri Aug 7 11:02:53 2015 -0700
Committer: Reynold Xin <rxin@databricks.com>
Committed: Fri Aug 7 11:02:53 2015 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/sql/SQLConf.scala    |   6 -
 .../apache/spark/sql/execution/Exchange.scala   |  48 +--
 .../sql/execution/SparkSqlSerializer2.scala     | 426 -------------------
 .../execution/SparkSqlSerializer2Suite.scala    | 221 ----------
 4 files changed, 24 insertions(+), 677 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/76eaa701/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index ef35c13..45d3d8c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -416,10 +416,6 @@ private[spark] object SQLConf {
   val USE_SQL_AGGREGATE2 = booleanConf("spark.sql.useAggregate2",
     defaultValue = Some(true), doc = "<TODO>")
 
-  val USE_SQL_SERIALIZER2 = booleanConf(
-    "spark.sql.useSerializer2",
-    defaultValue = Some(true), isPublic = false)
-
   object Deprecated {
     val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
   }
@@ -488,8 +484,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
 
   private[spark] def useSqlAggregate2: Boolean = getConf(USE_SQL_AGGREGATE2)
 
-  private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2)
-
   private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD)
 
   private[spark] def defaultSizeInBytes: Long =

http://git-wip-us.apache.org/repos/asf/spark/blob/76eaa701/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 6ea5eee..60087f2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.errors.attachTree
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.types.UserDefinedType
 import org.apache.spark.util.MutablePair
 import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv}
 
@@ -39,21 +40,34 @@ import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEn
 @DeveloperApi
 case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode {
 
-  override def outputPartitioning: Partitioning = newPartitioning
-
-  override def output: Seq[Attribute] = child.output
-
-  override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows
+  override def nodeName: String = if (tungstenMode) "TungstenExchange" else "Exchange"
 
-  override def canProcessSafeRows: Boolean = true
-
-  override def canProcessUnsafeRows: Boolean = {
+  /**
+   * Returns true iff the children outputs aggregate UDTs that are not part of the SQL type.
+   * This only happens with the old aggregate implementation and should be removed in 1.6.
+   */
+  private lazy val tungstenMode: Boolean = {
+    val unserializableUDT = child.schema.exists(_.dataType match {
+      case _: UserDefinedType[_] => true
+      case _ => false
+    })
     // Do not use the Unsafe path if we are using a RangePartitioning, since this may lead to
     // an interpreted RowOrdering being applied to an UnsafeRow, which will lead to
     // ClassCastExceptions at runtime. This check can be removed after SPARK-9054 is fixed.
-    !newPartitioning.isInstanceOf[RangePartitioning]
+    !unserializableUDT && !newPartitioning.isInstanceOf[RangePartitioning]
   }
 
+  override def outputPartitioning: Partitioning = newPartitioning
+
+  override def output: Seq[Attribute] = child.output
+
+  // This setting is somewhat counterintuitive:
+  // If the schema works with UnsafeRow, then we tell the planner that we don't support safe row,
+  // so the planner inserts a converter to convert data into UnsafeRow if needed.
+  override def outputsUnsafeRows: Boolean = tungstenMode
+  override def canProcessSafeRows: Boolean = !tungstenMode
+  override def canProcessUnsafeRows: Boolean = tungstenMode
+
   /**
    * Determines whether records must be defensively copied before being sent to the shuffle.
    * Several of Spark's shuffle components will buffer deserialized Java objects in memory. The
@@ -124,23 +138,9 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
 
   private val serializer: Serializer = {
     val rowDataTypes = child.output.map(_.dataType).toArray
-    // It is true when there is no field that needs to be write out.
-    // For now, we will not use SparkSqlSerializer2 when noField is true.
-    val noField = rowDataTypes == null || rowDataTypes.length == 0
-
-    val useSqlSerializer2 =
-        child.sqlContext.conf.useSqlSerializer2 &&   // SparkSqlSerializer2 is enabled.
-        SparkSqlSerializer2.support(rowDataTypes) &&  // The schema of row is supported.
-        !noField
-
-    if (child.outputsUnsafeRows) {
-      logInfo("Using UnsafeRowSerializer.")
+    if (tungstenMode) {
       new UnsafeRowSerializer(child.output.size)
-    } else if (useSqlSerializer2) {
-      logInfo("Using SparkSqlSerializer2.")
-      new SparkSqlSerializer2(rowDataTypes)
     } else {
-      logInfo("Using SparkSqlSerializer.")
       new SparkSqlSerializer(sparkConf)
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/76eaa701/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
deleted file mode 100644
index e811f1d..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
+++ /dev/null
@@ -1,426 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution
-
-import java.io._
-import java.math.{BigDecimal, BigInteger}
-import java.nio.ByteBuffer
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.Logging
-import org.apache.spark.serializer._
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow}
-import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
-
-/**
- * The serialization stream for [[SparkSqlSerializer2]]. It assumes that the object passed in
- * its `writeObject` are [[Product2]]. The serialization functions for the key and value of the
- * [[Product2]] are constructed based on their schemata.
- * The benefit of this serialization stream is that compared with general-purpose serializers like
- * Kryo and Java serializer, it can significantly reduce the size of serialized and has a lower
- * allocation cost, which can benefit the shuffle operation. Right now, its main limitations are:
- *  1. It does not support complex types, i.e. Map, Array, and Struct.
- *  2. It assumes that the objects passed in are [[Product2]]. So, it cannot be used when
- *     [[org.apache.spark.util.collection.ExternalSorter]]'s merge sort operation is used because
- *     the objects passed in the serializer are not in the type of [[Product2]]. Also also see
- *     the comment of the `serializer` method in [[Exchange]] for more information on it.
- *     the comment of the `serializer` method in [[Exchange]] for more information on it.
- */
-private[sql] class Serializer2SerializationStream(
-    rowSchema: Array[DataType],
-    out: OutputStream)
-  extends SerializationStream with Logging {
-
-  private val rowOut = new DataOutputStream(new BufferedOutputStream(out))
-  private val writeRowFunc = SparkSqlSerializer2.createSerializationFunction(rowSchema, rowOut)
-
-  override def writeObject[T: ClassTag](t: T): SerializationStream = {
-    val kv = t.asInstanceOf[Product2[InternalRow, InternalRow]]
-    writeKey(kv._1)
-    writeValue(kv._2)
-
-    this
-  }
-
-  override def writeKey[T: ClassTag](t: T): SerializationStream = {
-    // No-op.
-    this
-  }
-
-  override def writeValue[T: ClassTag](t: T): SerializationStream = {
-    writeRowFunc(t.asInstanceOf[InternalRow])
-    this
-  }
-
-  def flush(): Unit = {
-    rowOut.flush()
-  }
-
-  def close(): Unit = {
-    rowOut.close()
-  }
-}
-
-/**
- * The corresponding deserialization stream for [[Serializer2SerializationStream]].
- */
-private[sql] class Serializer2DeserializationStream(
-    rowSchema: Array[DataType],
-    in: InputStream)
-  extends DeserializationStream with Logging  {
-
-  private val rowIn = new DataInputStream(new BufferedInputStream(in))
-
-  private def rowGenerator(schema: Array[DataType]): () => (MutableRow) = {
-    if (schema == null) {
-      () => null
-    } else {
-      // It is safe to reuse the mutable row.
-      val mutableRow = new SpecificMutableRow(schema)
-      () => mutableRow
-    }
-  }
-
-  // Functions used to return rows for key and value.
-  private val getRow = rowGenerator(rowSchema)
-  // Functions used to read a serialized row from the InputStream and deserialize it.
-  private val readRowFunc = SparkSqlSerializer2.createDeserializationFunction(rowSchema, rowIn)
-
-  override def readObject[T: ClassTag](): T = {
-    readValue()
-  }
-
-  override def readKey[T: ClassTag](): T = {
-    null.asInstanceOf[T] // intentionally left blank.
-  }
-
-  override def readValue[T: ClassTag](): T = {
-    readRowFunc(getRow()).asInstanceOf[T]
-  }
-
-  override def close(): Unit = {
-    rowIn.close()
-  }
-}
-
-private[sql] class SparkSqlSerializer2Instance(
-    rowSchema: Array[DataType])
-  extends SerializerInstance {
-
-  def serialize[T: ClassTag](t: T): ByteBuffer =
-    throw new UnsupportedOperationException("Not supported.")
-
-  def deserialize[T: ClassTag](bytes: ByteBuffer): T =
-    throw new UnsupportedOperationException("Not supported.")
-
-  def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T =
-    throw new UnsupportedOperationException("Not supported.")
-
-  def serializeStream(s: OutputStream): SerializationStream = {
-    new Serializer2SerializationStream(rowSchema, s)
-  }
-
-  def deserializeStream(s: InputStream): DeserializationStream = {
-    new Serializer2DeserializationStream(rowSchema, s)
-  }
-}
-
-/**
- * SparkSqlSerializer2 is a special serializer that creates serialization function and
- * deserialization function based on the schema of data. It assumes that values passed in
- * are Rows.
- */
-private[sql] class SparkSqlSerializer2(rowSchema: Array[DataType])
-  extends Serializer
-  with Logging
-  with Serializable{
-
-  def newInstance(): SerializerInstance = new SparkSqlSerializer2Instance(rowSchema)
-
-  override def supportsRelocationOfSerializedObjects: Boolean = {
-    // SparkSqlSerializer2 is stateless and writes no stream headers
-    true
-  }
-}
-
-private[sql] object SparkSqlSerializer2 {
-
-  final val NULL = 0
-  final val NOT_NULL = 1
-
-  /**
-   * Check if rows with the given schema can be serialized with ShuffleSerializer.
-   * Right now, we do not support a schema having complex types or UDTs, or all data types
-   * of fields are NullTypes.
-   */
-  def support(schema: Array[DataType]): Boolean = {
-    if (schema == null) return true
-
-    var allNullTypes = true
-    var i = 0
-    while (i < schema.length) {
-      schema(i) match {
-        case NullType => // Do nothing
-        case udt: UserDefinedType[_] =>
-          allNullTypes = false
-          return false
-        case array: ArrayType =>
-          allNullTypes = false
-          return false
-        case map: MapType =>
-          allNullTypes = false
-          return false
-        case struct: StructType =>
-          allNullTypes = false
-          return false
-        case _ =>
-          allNullTypes = false
-      }
-      i += 1
-    }
-
-    // If types of fields are all NullTypes, we return false.
-    // Otherwise, we return true.
-    return !allNullTypes
-  }
-
-  /**
-   * The util function to create the serialization function based on the given schema.
-   */
-  def createSerializationFunction(schema: Array[DataType], out: DataOutputStream)
-    : InternalRow => Unit = {
-    (row: InternalRow) =>
-      // If the schema is null, the returned function does nothing when it get called.
-      if (schema != null) {
-        var i = 0
-        while (i < schema.length) {
-          schema(i) match {
-            // When we write values to the underlying stream, we also first write the null byte
-            // first. Then, if the value is not null, we write the contents out.
-
-            case NullType => // Write nothing.
-
-            case BooleanType =>
-              if (row.isNullAt(i)) {
-                out.writeByte(NULL)
-              } else {
-                out.writeByte(NOT_NULL)
-                out.writeBoolean(row.getBoolean(i))
-              }
-
-            case ByteType =>
-              if (row.isNullAt(i)) {
-                out.writeByte(NULL)
-              } else {
-                out.writeByte(NOT_NULL)
-                out.writeByte(row.getByte(i))
-              }
-
-            case ShortType =>
-              if (row.isNullAt(i)) {
-                out.writeByte(NULL)
-              } else {
-                out.writeByte(NOT_NULL)
-                out.writeShort(row.getShort(i))
-              }
-
-            case IntegerType | DateType =>
-              if (row.isNullAt(i)) {
-                out.writeByte(NULL)
-              } else {
-                out.writeByte(NOT_NULL)
-                out.writeInt(row.getInt(i))
-              }
-
-            case LongType | TimestampType =>
-              if (row.isNullAt(i)) {
-                out.writeByte(NULL)
-              } else {
-                out.writeByte(NOT_NULL)
-                out.writeLong(row.getLong(i))
-              }
-
-            case FloatType =>
-              if (row.isNullAt(i)) {
-                out.writeByte(NULL)
-              } else {
-                out.writeByte(NOT_NULL)
-                out.writeFloat(row.getFloat(i))
-              }
-
-            case DoubleType =>
-              if (row.isNullAt(i)) {
-                out.writeByte(NULL)
-              } else {
-                out.writeByte(NOT_NULL)
-                out.writeDouble(row.getDouble(i))
-              }
-
-            case StringType =>
-              if (row.isNullAt(i)) {
-                out.writeByte(NULL)
-              } else {
-                out.writeByte(NOT_NULL)
-                val bytes = row.getUTF8String(i).getBytes
-                out.writeInt(bytes.length)
-                out.write(bytes)
-              }
-
-            case BinaryType =>
-              if (row.isNullAt(i)) {
-                out.writeByte(NULL)
-              } else {
-                out.writeByte(NOT_NULL)
-                val bytes = row.getBinary(i)
-                out.writeInt(bytes.length)
-                out.write(bytes)
-              }
-
-            case decimal: DecimalType =>
-              if (row.isNullAt(i)) {
-                out.writeByte(NULL)
-              } else {
-                out.writeByte(NOT_NULL)
-                val value = row.getDecimal(i, decimal.precision, decimal.scale)
-                val javaBigDecimal = value.toJavaBigDecimal
-                // First, write out the unscaled value.
-                val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray
-                out.writeInt(bytes.length)
-                out.write(bytes)
-                // Then, write out the scale.
-                out.writeInt(javaBigDecimal.scale())
-              }
-          }
-          i += 1
-        }
-      }
-  }
-
-  /**
-   * The util function to create the deserialization function based on the given schema.
-   */
-  def createDeserializationFunction(
-      schema: Array[DataType],
-      in: DataInputStream): (MutableRow) => InternalRow = {
-    if (schema == null) {
-      (mutableRow: MutableRow) => null
-    } else {
-      (mutableRow: MutableRow) => {
-        var i = 0
-        while (i < schema.length) {
-          schema(i) match {
-            // When we read values from the underlying stream, we also first read the null byte
-            // first. Then, if the value is not null, we update the field of the mutable row.
-
-            case NullType => mutableRow.setNullAt(i) // Read nothing.
-
-            case BooleanType =>
-              if (in.readByte() == NULL) {
-                mutableRow.setNullAt(i)
-              } else {
-                mutableRow.setBoolean(i, in.readBoolean())
-              }
-
-            case ByteType =>
-              if (in.readByte() == NULL) {
-                mutableRow.setNullAt(i)
-              } else {
-                mutableRow.setByte(i, in.readByte())
-              }
-
-            case ShortType =>
-              if (in.readByte() == NULL) {
-                mutableRow.setNullAt(i)
-              } else {
-                mutableRow.setShort(i, in.readShort())
-              }
-
-            case IntegerType | DateType =>
-              if (in.readByte() == NULL) {
-                mutableRow.setNullAt(i)
-              } else {
-                mutableRow.setInt(i, in.readInt())
-              }
-
-            case LongType | TimestampType =>
-              if (in.readByte() == NULL) {
-                mutableRow.setNullAt(i)
-              } else {
-                mutableRow.setLong(i, in.readLong())
-              }
-
-            case FloatType =>
-              if (in.readByte() == NULL) {
-                mutableRow.setNullAt(i)
-              } else {
-                mutableRow.setFloat(i, in.readFloat())
-              }
-
-            case DoubleType =>
-              if (in.readByte() == NULL) {
-                mutableRow.setNullAt(i)
-              } else {
-                mutableRow.setDouble(i, in.readDouble())
-              }
-
-            case StringType =>
-              if (in.readByte() == NULL) {
-                mutableRow.setNullAt(i)
-              } else {
-                val length = in.readInt()
-                val bytes = new Array[Byte](length)
-                in.readFully(bytes)
-                mutableRow.update(i, UTF8String.fromBytes(bytes))
-              }
-
-            case BinaryType =>
-              if (in.readByte() == NULL) {
-                mutableRow.setNullAt(i)
-              } else {
-                val length = in.readInt()
-                val bytes = new Array[Byte](length)
-                in.readFully(bytes)
-                mutableRow.update(i, bytes)
-              }
-
-            case decimal: DecimalType =>
-              if (in.readByte() == NULL) {
-                mutableRow.setNullAt(i)
-              } else {
-                // First, read in the unscaled value.
-                val length = in.readInt()
-                val bytes = new Array[Byte](length)
-                in.readFully(bytes)
-                val unscaledVal = new BigInteger(bytes)
-                // Then, read the scale.
-                val scale = in.readInt()
-                // Finally, create the Decimal object and set it in the row.
-                mutableRow.update(i,
-                  Decimal(new BigDecimal(unscaledVal, scale), decimal.precision, decimal.scale))
-              }
-          }
-          i += 1
-        }
-
-        mutableRow
-      }
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/76eaa701/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
deleted file mode 100644
index 7978ed5..0000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
+++ /dev/null
@@ -1,221 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution
-
-import java.sql.{Timestamp, Date}
-
-import org.apache.spark.sql.test.TestSQLContext
-import org.scalatest.BeforeAndAfterAll
-
-import org.apache.spark.rdd.ShuffledRDD
-import org.apache.spark.serializer.Serializer
-import org.apache.spark.{ShuffleDependency, SparkFunSuite}
-import org.apache.spark.sql.types._
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.{MyDenseVectorUDT, QueryTest}
-
-class SparkSqlSerializer2DataTypeSuite extends SparkFunSuite {
-  // Make sure that we will not use serializer2 for unsupported data types.
-  def checkSupported(dataType: DataType, isSupported: Boolean): Unit = {
-    val testName =
-      s"${if (dataType == null) null else dataType.toString} is " +
-        s"${if (isSupported) "supported" else "unsupported"}"
-
-    test(testName) {
-      assert(SparkSqlSerializer2.support(Array(dataType)) === isSupported)
-    }
-  }
-
-  checkSupported(null, isSupported = true)
-  checkSupported(BooleanType, isSupported = true)
-  checkSupported(ByteType, isSupported = true)
-  checkSupported(ShortType, isSupported = true)
-  checkSupported(IntegerType, isSupported = true)
-  checkSupported(LongType, isSupported = true)
-  checkSupported(FloatType, isSupported = true)
-  checkSupported(DoubleType, isSupported = true)
-  checkSupported(DateType, isSupported = true)
-  checkSupported(TimestampType, isSupported = true)
-  checkSupported(StringType, isSupported = true)
-  checkSupported(BinaryType, isSupported = true)
-  checkSupported(DecimalType(10, 5), isSupported = true)
-  checkSupported(DecimalType.SYSTEM_DEFAULT, isSupported = true)
-
-  // If NullType is the only data type in the schema, we do not support it.
-  checkSupported(NullType, isSupported = false)
-  // For now, ArrayType, MapType, and StructType are not supported.
-  checkSupported(ArrayType(DoubleType, true), isSupported = false)
-  checkSupported(ArrayType(StringType, false), isSupported = false)
-  checkSupported(MapType(IntegerType, StringType, true), isSupported = false)
-  checkSupported(MapType(IntegerType, ArrayType(DoubleType), false), isSupported = false)
-  checkSupported(StructType(StructField("a", IntegerType, true) :: Nil), isSupported = false)
-  // UDTs are not supported right now.
-  checkSupported(new MyDenseVectorUDT, isSupported = false)
-}
-
-abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll {
-  var allColumns: String = _
-  val serializerClass: Class[Serializer] =
-    classOf[SparkSqlSerializer2].asInstanceOf[Class[Serializer]]
-  var numShufflePartitions: Int = _
-  var useSerializer2: Boolean = _
-
-  protected lazy val ctx = TestSQLContext
-
-  override def beforeAll(): Unit = {
-    numShufflePartitions = ctx.conf.numShufflePartitions
-    useSerializer2 = ctx.conf.useSqlSerializer2
-
-    ctx.sql("set spark.sql.useSerializer2=true")
-
-    val supportedTypes =
-      Seq(StringType, BinaryType, NullType, BooleanType,
-        ByteType, ShortType, IntegerType, LongType,
-        FloatType, DoubleType, DecimalType.SYSTEM_DEFAULT, DecimalType(6, 5),
-        DateType, TimestampType)
-
-    val fields = supportedTypes.zipWithIndex.map { case (dataType, index) =>
-      StructField(s"col$index", dataType, true)
-    }
-    allColumns = fields.map(_.name).mkString(",")
-    val schema = StructType(fields)
-
-    // Create a RDD with all data types supported by SparkSqlSerializer2.
-    val rdd =
-      ctx.sparkContext.parallelize((1 to 1000), 10).map { i =>
-        Row(
-          s"str${i}: test serializer2.",
-          s"binary${i}: test serializer2.".getBytes("UTF-8"),
-          null,
-          i % 2 == 0,
-          i.toByte,
-          i.toShort,
-          i,
-          Long.MaxValue - i.toLong,
-          (i + 0.25).toFloat,
-          (i + 0.75),
-          BigDecimal(Long.MaxValue.toString + ".12345"),
-          new java.math.BigDecimal(s"${i % 9 + 1}" + ".23456"),
-          new Date(i),
-          new Timestamp(i))
-      }
-
-    ctx.createDataFrame(rdd, schema).registerTempTable("shuffle")
-
-    super.beforeAll()
-  }
-
-  override def afterAll(): Unit = {
-    ctx.dropTempTable("shuffle")
-    ctx.sql(s"set spark.sql.shuffle.partitions=$numShufflePartitions")
-    ctx.sql(s"set spark.sql.useSerializer2=$useSerializer2")
-    super.afterAll()
-  }
-
-  def checkSerializer[T <: Serializer](
-      executedPlan: SparkPlan,
-      expectedSerializerClass: Class[T]): Unit = {
-    executedPlan.foreach {
-      case exchange: Exchange =>
-        val shuffledRDD = exchange.execute()
-        val dependency = shuffledRDD.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
-        val serializerNotSetMessage =
-          s"Expected $expectedSerializerClass as the serializer of Exchange. " +
-          s"However, the serializer was not set."
-        val serializer = dependency.serializer.getOrElse(fail(serializerNotSetMessage))
-        val isExpectedSerializer =
-          serializer.getClass == expectedSerializerClass ||
-            serializer.getClass == classOf[UnsafeRowSerializer]
-        val wrongSerializerErrorMessage =
-          s"Expected ${expectedSerializerClass.getCanonicalName} or " +
-            s"${classOf[UnsafeRowSerializer].getCanonicalName}. But " +
-            s"${serializer.getClass.getCanonicalName} is used."
-        assert(isExpectedSerializer, wrongSerializerErrorMessage)
-      case _ => // Ignore other nodes.
-    }
-  }
-
-  test("key schema and value schema are not nulls") {
-    val df = ctx.sql(s"SELECT DISTINCT ${allColumns} FROM shuffle")
-    checkSerializer(df.queryExecution.executedPlan, serializerClass)
-    checkAnswer(
-      df,
-      ctx.table("shuffle").collect())
-  }
-
-  test("key schema is null") {
-    val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",")
-    val df = ctx.sql(s"SELECT $aggregations FROM shuffle")
-    checkSerializer(df.queryExecution.executedPlan, serializerClass)
-    checkAnswer(
-      df,
-      Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000))
-  }
-
-  test("value schema is null") {
-    val df = ctx.sql(s"SELECT col0 FROM shuffle ORDER BY col0")
-    checkSerializer(df.queryExecution.executedPlan, serializerClass)
-    assert(df.map(r => r.getString(0)).collect().toSeq ===
-      ctx.table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq)
-  }
-
-  test("no map output field") {
-    val df = ctx.sql(s"SELECT 1 + 1 FROM shuffle")
-    checkSerializer(df.queryExecution.executedPlan, classOf[SparkSqlSerializer])
-  }
-
-  test("types of fields are all NullTypes") {
-    // Test range partitioning code path.
-    val nulls = ctx.sql(s"SELECT null as a, null as b, null as c")
-    val df = nulls.unionAll(nulls).sort("a")
-    checkSerializer(df.queryExecution.executedPlan, classOf[SparkSqlSerializer])
-    checkAnswer(
-      df,
-      Row(null, null, null) :: Row(null, null, null) :: Nil)
-
-    // Test hash partitioning code path.
-    val oneRow = ctx.sql(s"SELECT DISTINCT null, null, null FROM shuffle")
-    checkSerializer(oneRow.queryExecution.executedPlan, classOf[SparkSqlSerializer])
-    checkAnswer(
-      oneRow,
-      Row(null, null, null))
-  }
-}
-
-/** Tests SparkSqlSerializer2 with sort based shuffle without sort merge. */
-class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite {
-  override def beforeAll(): Unit = {
-    super.beforeAll()
-    // Sort merge will not be triggered.
-    val bypassMergeThreshold =
-      ctx.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
-    ctx.sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold-1}")
-  }
-}
-
-/** For now, we will use SparkSqlSerializer for sort based shuffle with sort merge. */
-class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite {
-
-  override def beforeAll(): Unit = {
-    super.beforeAll()
-    // To trigger the sort merge.
-    val bypassMergeThreshold =
-      ctx.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
-    ctx.sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold + 1}")
-  }
-}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message