spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From hvanhov...@apache.org
Subject spark git commit: [SPARK-17490][SQL] Optimize SerializeFromObject() for a primitive array
Date Mon, 07 Nov 2016 23:17:53 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-2.1 d1eac3ef4 -> 9873d57f2


[SPARK-17490][SQL] Optimize SerializeFromObject() for a primitive array

Waiting for merging #13680

This PR optimizes `SerializeFromObject()` for a primitive array. It is derived from #13758
and addresses one of the problems raised in #13758 in a simpler way.

The current implementation always generates `GenericArrayData` from `SerializeFromObject()`
for any type of array in a logical plan. This involves boxing in the constructor of `GenericArrayData`
when `SerializeFromObject()` has a primitive array.

This PR enables generating `UnsafeArrayData` from `SerializeFromObject()` for a primitive
array. This avoids boxing when creating an instance of `ArrayData` in the code generated by Catalyst.

This PR also generates `UnsafeArrayData` in the case of `RowEncoder.serializerFor` or `CatalystTypeConverters.createToCatalystConverter`.

Performance improvement of `SerializeFromObject()` is up to 2.0x

```
OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64
Intel Xeon E3-12xx v2 (Ivy Bridge)

Without this PR
Write an array in Dataset:               Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
Int                                            556 /  608         15.1          66.3       1.0X
Double                                        1668 / 1746          5.0         198.8       0.3X

with this PR
Write an array in Dataset:               Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
Int                                            352 /  401         23.8          42.0       1.0X
Double                                         821 /  885         10.2          97.9       0.4X
```

Here is an example program that exhibits this case, which occurs in MLlib as described in [SPARK-16070](https://issues.apache.org/jira/browse/SPARK-16070).

```
sparkContext.parallelize(Seq(Array(1, 2)), 1).toDS.map(e => e).show
```

Generated code before applying this PR

``` java
/* 039 */   protected void processNext() throws java.io.IOException {
/* 040 */     while (inputadapter_input.hasNext()) {
/* 041 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 042 */       int[] inputadapter_value = (int[])inputadapter_row.get(0, null);
/* 043 */
/* 044 */       Object mapelements_obj = ((Expression) references[0]).eval(null);
/* 045 */       scala.Function1 mapelements_value1 = (scala.Function1) mapelements_obj;
/* 046 */
/* 047 */       boolean mapelements_isNull = false || false;
/* 048 */       int[] mapelements_value = null;
/* 049 */       if (!mapelements_isNull) {
/* 050 */         Object mapelements_funcResult = null;
/* 051 */         mapelements_funcResult = mapelements_value1.apply(inputadapter_value);
/* 052 */         if (mapelements_funcResult == null) {
/* 053 */           mapelements_isNull = true;
/* 054 */         } else {
/* 055 */           mapelements_value = (int[]) mapelements_funcResult;
/* 056 */         }
/* 057 */
/* 058 */       }
/* 059 */       mapelements_isNull = mapelements_value == null;
/* 060 */
/* 061 */       serializefromobject_argIsNulls[0] = mapelements_isNull;
/* 062 */       serializefromobject_argValue = mapelements_value;
/* 063 */
/* 064 */       boolean serializefromobject_isNull = false;
/* 065 */       for (int idx = 0; idx < 1; idx++) {
/* 066 */         if (serializefromobject_argIsNulls[idx]) { serializefromobject_isNull =
true; break; }
/* 067 */       }
/* 068 */
/* 069 */       final ArrayData serializefromobject_value = serializefromobject_isNull ? null
: new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_argValue);
/* 070 */       serializefromobject_holder.reset();
/* 071 */
/* 072 */       serializefromobject_rowWriter.zeroOutNullBytes();
/* 073 */
/* 074 */       if (serializefromobject_isNull) {
/* 075 */         serializefromobject_rowWriter.setNullAt(0);
/* 076 */       } else {
/* 077 */         // Remember the current cursor so that we can calculate how many bytes are
/* 078 */         // written later.
/* 079 */         final int serializefromobject_tmpCursor = serializefromobject_holder.cursor;
/* 080 */
/* 081 */         if (serializefromobject_value instanceof UnsafeArrayData) {
/* 082 */           final int serializefromobject_sizeInBytes = ((UnsafeArrayData) serializefromobject_value).getSizeInBytes();
/* 083 */           // grow the global buffer before writing data.
/* 084 */           serializefromobject_holder.grow(serializefromobject_sizeInBytes);
/* 085 */           ((UnsafeArrayData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer,
serializefromobject_holder.cursor);
/* 086 */           serializefromobject_holder.cursor += serializefromobject_sizeInBytes;
/* 087 */
/* 088 */         } else {
/* 089 */           final int serializefromobject_numElements = serializefromobject_value.numElements();
/* 090 */           serializefromobject_arrayWriter.initialize(serializefromobject_holder,
serializefromobject_numElements, 4);
/* 091 */
/* 092 */           for (int serializefromobject_index = 0; serializefromobject_index <
serializefromobject_numElements; serializefromobject_index++) {
/* 093 */             if (serializefromobject_value.isNullAt(serializefromobject_index)) {
/* 094 */               serializefromobject_arrayWriter.setNullInt(serializefromobject_index);
/* 095 */             } else {
/* 096 */               final int serializefromobject_element = serializefromobject_value.getInt(serializefromobject_index);
/* 097 */               serializefromobject_arrayWriter.write(serializefromobject_index, serializefromobject_element);
/* 098 */             }
/* 099 */           }
/* 100 */         }
/* 101 */
/* 102 */         serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor,
serializefromobject_holder.cursor - serializefromobject_tmpCursor);
/* 103 */       }
/* 104 */       serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize());
/* 105 */       append(serializefromobject_result);
/* 106 */       if (shouldStop()) return;
/* 107 */     }
/* 108 */   }
/* 109 */ }
```

Generated code after applying this PR

``` java
/* 035 */   protected void processNext() throws java.io.IOException {
/* 036 */     while (inputadapter_input.hasNext()) {
/* 037 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 038 */       int[] inputadapter_value = (int[])inputadapter_row.get(0, null);
/* 039 */
/* 040 */       Object mapelements_obj = ((Expression) references[0]).eval(null);
/* 041 */       scala.Function1 mapelements_value1 = (scala.Function1) mapelements_obj;
/* 042 */
/* 043 */       boolean mapelements_isNull = false || false;
/* 044 */       int[] mapelements_value = null;
/* 045 */       if (!mapelements_isNull) {
/* 046 */         Object mapelements_funcResult = null;
/* 047 */         mapelements_funcResult = mapelements_value1.apply(inputadapter_value);
/* 048 */         if (mapelements_funcResult == null) {
/* 049 */           mapelements_isNull = true;
/* 050 */         } else {
/* 051 */           mapelements_value = (int[]) mapelements_funcResult;
/* 052 */         }
/* 053 */
/* 054 */       }
/* 055 */       mapelements_isNull = mapelements_value == null;
/* 056 */
/* 057 */       boolean serializefromobject_isNull = mapelements_isNull;
/* 058 */       final ArrayData serializefromobject_value = serializefromobject_isNull ? null
: org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.fromPrimitiveArray(mapelements_value);
/* 059 */       serializefromobject_isNull = serializefromobject_value == null;
/* 060 */       serializefromobject_holder.reset();
/* 061 */
/* 062 */       serializefromobject_rowWriter.zeroOutNullBytes();
/* 063 */
/* 064 */       if (serializefromobject_isNull) {
/* 065 */         serializefromobject_rowWriter.setNullAt(0);
/* 066 */       } else {
/* 067 */         // Remember the current cursor so that we can calculate how many bytes are
/* 068 */         // written later.
/* 069 */         final int serializefromobject_tmpCursor = serializefromobject_holder.cursor;
/* 070 */
/* 071 */         if (serializefromobject_value instanceof UnsafeArrayData) {
/* 072 */           final int serializefromobject_sizeInBytes = ((UnsafeArrayData) serializefromobject_value).getSizeInBytes();
/* 073 */           // grow the global buffer before writing data.
/* 074 */           serializefromobject_holder.grow(serializefromobject_sizeInBytes);
/* 075 */           ((UnsafeArrayData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer,
serializefromobject_holder.cursor);
/* 076 */           serializefromobject_holder.cursor += serializefromobject_sizeInBytes;
/* 077 */
/* 078 */         } else {
/* 079 */           final int serializefromobject_numElements = serializefromobject_value.numElements();
/* 080 */           serializefromobject_arrayWriter.initialize(serializefromobject_holder,
serializefromobject_numElements, 4);
/* 081 */
/* 082 */           for (int serializefromobject_index = 0; serializefromobject_index <
serializefromobject_numElements; serializefromobject_index++) {
/* 083 */             if (serializefromobject_value.isNullAt(serializefromobject_index)) {
/* 084 */               serializefromobject_arrayWriter.setNullInt(serializefromobject_index);
/* 085 */             } else {
/* 086 */               final int serializefromobject_element = serializefromobject_value.getInt(serializefromobject_index);
/* 087 */               serializefromobject_arrayWriter.write(serializefromobject_index, serializefromobject_element);
/* 088 */             }
/* 089 */           }
/* 090 */         }
/* 091 */
/* 092 */         serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor,
serializefromobject_holder.cursor - serializefromobject_tmpCursor);
/* 093 */       }
/* 094 */       serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize());
/* 095 */       append(serializefromobject_result);
/* 096 */       if (shouldStop()) return;
/* 097 */     }
/* 098 */   }
/* 099 */ }
```

Added tests in `DatasetSuite`, `RowEncoderSuite`, and `CatalystTypeConvertersSuite`.

Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com>

Closes #15044 from kiszk/SPARK-17490.

(cherry picked from commit 19cf208063f035d793d2306295a251a9af7e32f6)
Signed-off-by: Herman van Hovell <hvanhovell@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9873d57f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9873d57f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9873d57f

Branch: refs/heads/branch-2.1
Commit: 9873d57f2c76d1a6995c4ff5a45be1259a7948f0
Parents: d1eac3e
Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com>
Authored: Tue Nov 8 00:14:57 2016 +0100
Committer: Herman van Hovell <hvanhovell@databricks.com>
Committed: Tue Nov 8 00:17:05 2016 +0100

----------------------------------------------------------------------
 .../spark/sql/catalyst/ScalaReflection.scala    | 16 ++++
 .../sql/catalyst/encoders/RowEncoder.scala      | 27 +++----
 .../spark/sql/catalyst/util/ArrayData.scala     | 15 +++-
 .../catalyst/CatalystTypeConvertersSuite.scala  | 33 ++++++++
 .../sql/catalyst/encoders/RowEncoderSuite.scala | 26 +++++++
 .../org/apache/spark/sql/DatasetSuite.scala     | 18 +++++
 .../benchmark/PrimitiveArrayBenchmark.scala     | 82 ++++++++++++++++++++
 7 files changed, 203 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9873d57f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 31c6e5d..7bcaea7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -441,6 +441,22 @@ object ScalaReflection extends ScalaReflection {
           val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath
           MapObjects(serializerFor(_, elementType, newPath), input, dt)
 
+         case dt @ (BooleanType | ByteType | ShortType | IntegerType | LongType |
+                    FloatType | DoubleType) =>
+          val cls = input.dataType.asInstanceOf[ObjectType].cls
+          if (cls.isArray && cls.getComponentType.isPrimitive) {
+            StaticInvoke(
+              classOf[UnsafeArrayData],
+              ArrayType(dt, false),
+              "fromPrimitiveArray",
+              input :: Nil)
+          } else {
+            NewInstance(
+              classOf[GenericArrayData],
+              input :: Nil,
+              dataType = ArrayType(dt, schemaFor(elementType).nullable))
+          }
+
         case dt =>
           NewInstance(
             classOf[GenericArrayData],

http://git-wip-us.apache.org/repos/asf/spark/blob/9873d57f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 2a6fcd0..e95e97b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -23,7 +23,7 @@ import scala.reflect.ClassTag
 import org.apache.spark.SparkException
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData}
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
 import org.apache.spark.sql.catalyst.expressions.objects._
@@ -119,18 +119,19 @@ object RowEncoder {
         "fromString",
         inputObject :: Nil)
 
-    case t @ ArrayType(et, _) => et match {
-      case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType
=>
-        // TODO: validate input type for primitive array.
-        NewInstance(
-          classOf[GenericArrayData],
-          inputObject :: Nil,
-          dataType = t)
-      case _ => MapObjects(
-        element => serializerFor(ValidateExternalType(element, et), et),
-        inputObject,
-        ObjectType(classOf[Object]))
-    }
+    case t @ ArrayType(et, cn) =>
+      et match {
+        case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType
=>
+          StaticInvoke(
+            classOf[ArrayData],
+            t,
+            "toArrayData",
+            inputObject :: Nil)
+        case _ => MapObjects(
+          element => serializerFor(ValidateExternalType(element, et), et),
+          inputObject,
+          ObjectType(classOf[Object]))
+      }
 
     case t @ MapType(kt, vt, valueNullable) =>
       val keys =

http://git-wip-us.apache.org/repos/asf/spark/blob/9873d57f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
index cad4a08..140e86d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
@@ -19,9 +19,22 @@ package org.apache.spark.sql.catalyst.util
 
 import scala.reflect.ClassTag
 
-import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
+import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, UnsafeArrayData}
 import org.apache.spark.sql.types.DataType
 
+object ArrayData {
+  def toArrayData(input: Any): ArrayData = input match {
+    case a: Array[Boolean] => UnsafeArrayData.fromPrimitiveArray(a)
+    case a: Array[Byte] => UnsafeArrayData.fromPrimitiveArray(a)
+    case a: Array[Short] => UnsafeArrayData.fromPrimitiveArray(a)
+    case a: Array[Int] => UnsafeArrayData.fromPrimitiveArray(a)
+    case a: Array[Long] => UnsafeArrayData.fromPrimitiveArray(a)
+    case a: Array[Float] => UnsafeArrayData.fromPrimitiveArray(a)
+    case a: Array[Double] => UnsafeArrayData.fromPrimitiveArray(a)
+    case other => new GenericArrayData(other)
+  }
+}
+
 abstract class ArrayData extends SpecializedGetters with Serializable {
   def numElements(): Int
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9873d57f/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
index 03bb102..f3702ec 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
+import org.apache.spark.sql.catalyst.util.GenericArrayData
 import org.apache.spark.sql.types._
 
 class CatalystTypeConvertersSuite extends SparkFunSuite {
@@ -61,4 +63,35 @@ class CatalystTypeConvertersSuite extends SparkFunSuite {
   test("option handling in createToCatalystConverter") {
     assert(CatalystTypeConverters.createToCatalystConverter(IntegerType)(Some(123)) === 123)
   }
+
+  test("primitive array handling") {
+    val intArray = Array(1, 100, 10000)
+    val intUnsafeArray = UnsafeArrayData.fromPrimitiveArray(intArray)
+    val intArrayType = ArrayType(IntegerType, false)
+    assert(CatalystTypeConverters.createToScalaConverter(intArrayType)(intUnsafeArray) ===
intArray)
+
+    val doubleArray = Array(1.1, 111.1, 11111.1)
+    val doubleUnsafeArray = UnsafeArrayData.fromPrimitiveArray(doubleArray)
+    val doubleArrayType = ArrayType(DoubleType, false)
+    assert(CatalystTypeConverters.createToScalaConverter(doubleArrayType)(doubleUnsafeArray)
+      === doubleArray)
+  }
+
+  test("An array with null handling") {
+    val intArray = Array(1, null, 100, null, 10000)
+    val intGenericArray = new GenericArrayData(intArray)
+    val intArrayType = ArrayType(IntegerType, true)
+    assert(CatalystTypeConverters.createToScalaConverter(intArrayType)(intGenericArray)
+      === intArray)
+    assert(CatalystTypeConverters.createToCatalystConverter(intArrayType)(intArray)
+      == intGenericArray)
+
+    val doubleArray = Array(1.1, null, 111.1, null, 11111.1)
+    val doubleGenericArray = new GenericArrayData(doubleArray)
+    val doubleArrayType = ArrayType(DoubleType, true)
+    assert(CatalystTypeConverters.createToScalaConverter(doubleArrayType)(doubleGenericArray)
+      === doubleArray)
+    assert(CatalystTypeConverters.createToCatalystConverter(doubleArrayType)(doubleArray)
+      == doubleGenericArray)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/9873d57f/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 2e513ea..1a5569a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -191,6 +191,32 @@ class RowEncoderSuite extends SparkFunSuite {
     assert(encoder.serializer.head.nullable == false)
   }
 
+  test("RowEncoder should support primitive arrays") {
+    val schema = new StructType()
+      .add("booleanPrimitiveArray", ArrayType(BooleanType, false))
+      .add("bytePrimitiveArray", ArrayType(ByteType, false))
+      .add("shortPrimitiveArray", ArrayType(ShortType, false))
+      .add("intPrimitiveArray", ArrayType(IntegerType, false))
+      .add("longPrimitiveArray", ArrayType(LongType, false))
+      .add("floatPrimitiveArray", ArrayType(FloatType, false))
+      .add("doublePrimitiveArray", ArrayType(DoubleType, false))
+    val encoder = RowEncoder(schema).resolveAndBind()
+    val input = Seq(
+      Array(true, false),
+      Array(1.toByte, 64.toByte, Byte.MaxValue),
+      Array(1.toShort, 255.toShort, Short.MaxValue),
+      Array(1, 10000, Int.MaxValue),
+      Array(1.toLong, 1000000.toLong, Long.MaxValue),
+      Array(1.1.toFloat, 123.456.toFloat, Float.MaxValue),
+      Array(11.1111, 123456.7890123, Double.MaxValue)
+    )
+    val row = encoder.toRow(Row.fromSeq(input))
+    val convertedBack = encoder.fromRow(row)
+    input.zipWithIndex.map { case (array, index) =>
+      assert(convertedBack.getSeq(index) === array)
+    }
+  }
+
   test("RowEncoder should support array as the external type for ArrayType") {
     val schema = new StructType()
       .add("array", ArrayType(IntegerType))

http://git-wip-us.apache.org/repos/asf/spark/blob/9873d57f/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index a8dd422..81fa8cb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1033,6 +1033,24 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       checkAnswer(agg, ds.groupBy('id % 2).agg(count('id)))
     }
   }
+
+  test("identity map for primitive arrays") {
+    val arrayByte = Array(1.toByte, 2.toByte, 3.toByte)
+    val arrayInt = Array(1, 2, 3)
+    val arrayLong = Array(1.toLong, 2.toLong, 3.toLong)
+    val arrayDouble = Array(1.1, 2.2, 3.3)
+    val arrayString = Array("a", "b", "c")
+    val dsByte = sparkContext.parallelize(Seq(arrayByte), 1).toDS.map(e => e)
+    val dsInt = sparkContext.parallelize(Seq(arrayInt), 1).toDS.map(e => e)
+    val dsLong = sparkContext.parallelize(Seq(arrayLong), 1).toDS.map(e => e)
+    val dsDouble = sparkContext.parallelize(Seq(arrayDouble), 1).toDS.map(e => e)
+    val dsString = sparkContext.parallelize(Seq(arrayString), 1).toDS.map(e => e)
+    checkDataset(dsByte, arrayByte)
+    checkDataset(dsInt, arrayInt)
+    checkDataset(dsLong, arrayLong)
+    checkDataset(dsDouble, arrayDouble)
+    checkDataset(dsString, arrayString)
+  }
 }
 
 case class Generic[T](id: T, value: Double)

http://git-wip-us.apache.org/repos/asf/spark/blob/9873d57f/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala
new file mode 100644
index 0000000..e7c8f27
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.benchmark
+
+import scala.concurrent.duration._
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.util.Benchmark
+
+/**
+ * Benchmark [[PrimitiveArray]] for DataFrame and Dataset program using primitive array
+ * To run this:
+ *  1. replace ignore(...) with test(...)
+ *  2. build/sbt "sql/test-only *benchmark.PrimitiveArrayBenchmark"
+ *
+ * Benchmarks in this file are skipped in normal builds.
+ */
+class PrimitiveArrayBenchmark extends BenchmarkBase {
+
+  def writeDatasetArray(iters: Int): Unit = {
+    import sparkSession.implicits._
+
+    val count = 1024 * 1024 * 2
+
+    val sc = sparkSession.sparkContext
+    val primitiveIntArray = Array.fill[Int](count)(65535)
+    val dsInt = sc.parallelize(Seq(primitiveIntArray), 1).toDS
+    dsInt.count  // force to build dataset
+    val intArray = { i: Int =>
+      var n = 0
+      var len = 0
+      while (n < iters) {
+        len += dsInt.map(e => e).queryExecution.toRdd.collect.length
+        n += 1
+      }
+    }
+    val primitiveDoubleArray = Array.fill[Double](count)(65535.0)
+    val dsDouble = sc.parallelize(Seq(primitiveDoubleArray), 1).toDS
+    dsDouble.count  // force to build dataset
+    val doubleArray = { i: Int =>
+      var n = 0
+      var len = 0
+      while (n < iters) {
+        len += dsDouble.map(e => e).queryExecution.toRdd.collect.length
+        n += 1
+      }
+    }
+
+    val benchmark = new Benchmark("Write an array in Dataset", count * iters)
+    benchmark.addCase("Int   ")(intArray)
+    benchmark.addCase("Double")(doubleArray)
+    benchmark.run
+    /*
+    OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64
+    Intel Xeon E3-12xx v2 (Ivy Bridge)
+    Write an array in Dataset:               Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)
  Relative
+    ------------------------------------------------------------------------------------------------
+    Int                                            352 /  401         23.8          42.0
      1.0X
+    Double                                         821 /  885         10.2          97.9
      0.4X
+    */
+  }
+
+  ignore("Write an array in Dataset") {
+    writeDatasetArray(4)
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message