spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From wenc...@apache.org
Subject spark git commit: [SPARK-21440][SQL][PYSPARK] Refactor ArrowConverters and add ArrayType and StructType support.
Date Thu, 27 Jul 2017 11:20:00 GMT
Repository: spark
Updated Branches:
  refs/heads/master ebbe589d1 -> 2ff35a057


[SPARK-21440][SQL][PYSPARK] Refactor ArrowConverters and add ArrayType and StructType support.

## What changes were proposed in this pull request?

This is a refactoring of `ArrowConverters` and related classes.

1. Refactor `ColumnWriter` as `ArrowWriter`.
2. Add `ArrayType` and `StructType` support.
3. Refactor `ArrowConverters` to skip intermediate `ArrowRecordBatch` creation.

## How was this patch tested?

Added new tests; the change is also covered by existing tests.

Author: Takuya UESHIN <ueshin@databricks.com>

Closes #18655 from ueshin/issues/SPARK-21440.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2ff35a05
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2ff35a05
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2ff35a05

Branch: refs/heads/master
Commit: 2ff35a057efd36bd5c8a545a1ec3bc341432a904
Parents: ebbe589
Author: Takuya UESHIN <ueshin@databricks.com>
Authored: Thu Jul 27 19:19:51 2017 +0800
Committer: Wenchen Fan <wenchen@databricks.com>
Committed: Thu Jul 27 19:19:51 2017 +0800

----------------------------------------------------------------------
 python/pyspark/sql/tests.py                     |   4 +-
 .../scala/org/apache/spark/sql/Dataset.scala    |   4 +-
 .../sql/execution/arrow/ArrowConverters.scala   | 351 ++-------------
 .../spark/sql/execution/arrow/ArrowWriter.scala | 323 ++++++++++++++
 .../execution/arrow/ArrowConvertersSuite.scala  | 447 ++++++++++++++++++-
 .../sql/execution/arrow/ArrowWriterSuite.scala  | 260 +++++++++++
 6 files changed, 1074 insertions(+), 315 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2ff35a05/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 1c1a0ca..54756ed 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3018,8 +3018,8 @@ class ArrowTests(ReusedPySparkTestCase):
         self.assertTrue(df_without.equals(df_with_arrow), msg=msg)
 
     def test_unsupported_datatype(self):
-        schema = StructType([StructField("array", ArrayType(IntegerType(), False), True)])
-        df = self.spark.createDataFrame([([1, 2, 3],)], schema=schema)
+        schema = StructType([StructField("dt", DateType(), True)])
+        df = self.spark.createDataFrame([(datetime.date(1970, 1, 1),)], schema=schema)
         with QuietTest(self.sc):
             self.assertRaises(Exception, lambda: df.toPandas())
 

http://git-wip-us.apache.org/repos/asf/spark/blob/2ff35a05/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 71ab0dd..9007367 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -27,6 +27,7 @@ import scala.util.control.NonFatal
 
 import org.apache.commons.lang3.StringUtils
 
+import org.apache.spark.TaskContext
 import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.java.function._
@@ -3090,7 +3091,8 @@ class Dataset[T] private[sql](
     val schemaCaptured = this.schema
     val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
     queryExecution.toRdd.mapPartitionsInternal { iter =>
-      ArrowConverters.toPayloadIterator(iter, schemaCaptured, maxRecordsPerBatch)
+      val context = TaskContext.get()
+      ArrowConverters.toPayloadIterator(iter, schemaCaptured, maxRecordsPerBatch, context)
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/2ff35a05/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
index c913efe..240f38f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -20,18 +20,13 @@ package org.apache.spark.sql.execution.arrow
 import java.io.ByteArrayOutputStream
 import java.nio.channels.Channels
 
-import scala.collection.JavaConverters._
-
-import io.netty.buffer.ArrowBuf
-import org.apache.arrow.memory.{BufferAllocator, RootAllocator}
+import org.apache.arrow.memory.BufferAllocator
 import org.apache.arrow.vector._
-import org.apache.arrow.vector.BaseValueVector.BaseMutator
 import org.apache.arrow.vector.file._
-import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch}
-import org.apache.arrow.vector.types.FloatingPointPrecision
-import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
+import org.apache.arrow.vector.schema.ArrowRecordBatch
 import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel
 
+import org.apache.spark.TaskContext
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
@@ -55,19 +50,6 @@ private[sql] class ArrowPayload private[arrow] (payload: Array[Byte]) extends Se
   def asPythonSerializable: Array[Byte] = payload
 }
 
-private[sql] object ArrowPayload {
-
-  /**
-   * Create an ArrowPayload from an ArrowRecordBatch and Spark schema.
-   */
-  def apply(
-      batch: ArrowRecordBatch,
-      schema: StructType,
-      allocator: BufferAllocator): ArrowPayload = {
-    new ArrowPayload(ArrowConverters.batchToByteArray(batch, schema, allocator))
-  }
-}
-
 private[sql] object ArrowConverters {
 
   /**
@@ -77,95 +59,55 @@ private[sql] object ArrowConverters {
   private[sql] def toPayloadIterator(
       rowIter: Iterator[InternalRow],
       schema: StructType,
-      maxRecordsPerBatch: Int): Iterator[ArrowPayload] = {
-    new Iterator[ArrowPayload] {
-      private val _allocator = new RootAllocator(Long.MaxValue)
-      private var _nextPayload = if (rowIter.nonEmpty) convert() else null
+      maxRecordsPerBatch: Int,
+      context: TaskContext): Iterator[ArrowPayload] = {
 
-      override def hasNext: Boolean = _nextPayload != null
-
-      override def next(): ArrowPayload = {
-        val obj = _nextPayload
-        if (hasNext) {
-          if (rowIter.hasNext) {
-            _nextPayload = convert()
-          } else {
-            _allocator.close()
-            _nextPayload = null
-          }
-        }
-        obj
-      }
-
-      private def convert(): ArrowPayload = {
-        val batch = internalRowIterToArrowBatch(rowIter, schema, _allocator, maxRecordsPerBatch)
-        ArrowPayload(batch, schema, _allocator)
-      }
-    }
-  }
+    val arrowSchema = ArrowUtils.toArrowSchema(schema)
+    val allocator =
+      ArrowUtils.rootAllocator.newChildAllocator("toPayloadIterator", 0, Long.MaxValue)
 
-  /**
-   * Iterate over InternalRows and write to an ArrowRecordBatch, stopping when rowIter is consumed
-   * or the number of records in the batch equals maxRecordsInBatch.  If maxRecordsPerBatch is 0,
-   * then rowIter will be fully consumed.
-   */
-  private def internalRowIterToArrowBatch(
-      rowIter: Iterator[InternalRow],
-      schema: StructType,
-      allocator: BufferAllocator,
-      maxRecordsPerBatch: Int = 0): ArrowRecordBatch = {
+    val root = VectorSchemaRoot.create(arrowSchema, allocator)
+    val arrowWriter = ArrowWriter.create(root)
 
-    val columnWriters = schema.fields.zipWithIndex.map { case (field, ordinal) =>
-      ColumnWriter(field.dataType, ordinal, allocator).init()
-    }
+    var closed = false
 
-    val writerLength = columnWriters.length
-    var recordsInBatch = 0
-    while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || recordsInBatch < maxRecordsPerBatch)) {
-      val row = rowIter.next()
-      var i = 0
-      while (i < writerLength) {
-        columnWriters(i).write(row)
-        i += 1
+    context.addTaskCompletionListener { _ =>
+      if (!closed) {
+        root.close()
+        allocator.close()
       }
-      recordsInBatch += 1
     }
 
-    val (fieldNodes, bufferArrays) = columnWriters.map(_.finish()).unzip
-    val buffers = bufferArrays.flatten
-
-    val rowLength = if (fieldNodes.nonEmpty) fieldNodes.head.getLength else 0
-    val recordBatch = new ArrowRecordBatch(rowLength,
-      fieldNodes.toList.asJava, buffers.toList.asJava)
+    new Iterator[ArrowPayload] {
 
-    buffers.foreach(_.release())
-    recordBatch
-  }
+      override def hasNext: Boolean = rowIter.hasNext || {
+        root.close()
+        allocator.close()
+        closed = true
+        false
+      }
 
-  /**
-   * Convert an ArrowRecordBatch to a byte array and close batch to release resources. Once closed,
-   * the batch can no longer be used.
-   */
-  private[arrow] def batchToByteArray(
-      batch: ArrowRecordBatch,
-      schema: StructType,
-      allocator: BufferAllocator): Array[Byte] = {
-    val arrowSchema = ArrowUtils.toArrowSchema(schema)
-    val root = VectorSchemaRoot.create(arrowSchema, allocator)
-    val out = new ByteArrayOutputStream()
-    val writer = new ArrowFileWriter(root, null, Channels.newChannel(out))
+      override def next(): ArrowPayload = {
+        val out = new ByteArrayOutputStream()
+        val writer = new ArrowFileWriter(root, null, Channels.newChannel(out))
+
+        Utils.tryWithSafeFinally {
+          var rowCount = 0
+          while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) {
+            val row = rowIter.next()
+            arrowWriter.write(row)
+            rowCount += 1
+          }
+          arrowWriter.finish()
+          writer.writeBatch()
+        } {
+          arrowWriter.reset()
+          writer.close()
+        }
 
-    // Write a batch to byte stream, ensure the batch, allocator and writer are closed
-    Utils.tryWithSafeFinally {
-      val loader = new VectorLoader(root)
-      loader.load(batch)
-      writer.writeBatch()  // writeBatch can throw IOException
-    } {
-      batch.close()
-      root.close()
-      writer.close()
+        new ArrowPayload(out.toByteArray)
+      }
     }
-    out.toByteArray
   }
 
   /**
@@ -188,214 +130,3 @@ private[sql] object ArrowConverters {
     }
   }
 }
-
-/**
- * Interface for writing InternalRows to Arrow Buffers.
- */
-private[arrow] trait ColumnWriter {
-  def init(): this.type
-  def write(row: InternalRow): Unit
-
-  /**
-   * Clear the column writer and return the ArrowFieldNode and ArrowBuf.
-   * This should be called only once after all the data is written.
-   */
-  def finish(): (ArrowFieldNode, Array[ArrowBuf])
-}
-
-/**
- * Base class for flat arrow column writer, i.e., column without children.
- */
-private[arrow] abstract class PrimitiveColumnWriter(val ordinal: Int)
-  extends ColumnWriter {
-
-  def getFieldType(dtype: ArrowType): FieldType = FieldType.nullable(dtype)
-
-  def valueVector: BaseDataValueVector
-  def valueMutator: BaseMutator
-
-  def setNull(): Unit
-  def setValue(row: InternalRow): Unit
-
-  protected var count = 0
-  protected var nullCount = 0
-
-  override def init(): this.type = {
-    valueVector.allocateNew()
-    this
-  }
-
-  override def write(row: InternalRow): Unit = {
-    if (row.isNullAt(ordinal)) {
-      setNull()
-      nullCount += 1
-    } else {
-      setValue(row)
-    }
-    count += 1
-  }
-
-  override def finish(): (ArrowFieldNode, Array[ArrowBuf]) = {
-    valueMutator.setValueCount(count)
-    val fieldNode = new ArrowFieldNode(count, nullCount)
-    val valueBuffers = valueVector.getBuffers(true)
-    (fieldNode, valueBuffers)
-  }
-}
-
-private[arrow] class BooleanColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator)
-  extends PrimitiveColumnWriter(ordinal) {
-  override val valueVector: NullableBitVector
-    = new NullableBitVector("BooleanValue", getFieldType(dtype), allocator)
-  override val valueMutator: NullableBitVector#Mutator = valueVector.getMutator
-
-  override def setNull(): Unit = valueMutator.setNull(count)
-  override def setValue(row: InternalRow): Unit
-    = valueMutator.setSafe(count, if (row.getBoolean(ordinal)) 1 else 0 )
-}
-
-private[arrow] class ShortColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator)
-  extends PrimitiveColumnWriter(ordinal) {
-  override val valueVector: NullableSmallIntVector
-    = new NullableSmallIntVector("ShortValue", getFieldType(dtype: ArrowType), allocator)
-  override val valueMutator: NullableSmallIntVector#Mutator = valueVector.getMutator
-
-  override def setNull(): Unit = valueMutator.setNull(count)
-  override def setValue(row: InternalRow): Unit
-    = valueMutator.setSafe(count, row.getShort(ordinal))
-}
-
-private[arrow] class IntegerColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator)
-  extends PrimitiveColumnWriter(ordinal) {
-  override val valueVector: NullableIntVector
-    = new NullableIntVector("IntValue", getFieldType(dtype), allocator)
-  override val valueMutator: NullableIntVector#Mutator = valueVector.getMutator
-
-  override def setNull(): Unit = valueMutator.setNull(count)
-  override def setValue(row: InternalRow): Unit
-    = valueMutator.setSafe(count, row.getInt(ordinal))
-}
-
-private[arrow] class LongColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator)
-  extends PrimitiveColumnWriter(ordinal) {
-  override val valueVector: NullableBigIntVector
-    = new NullableBigIntVector("LongValue", getFieldType(dtype), allocator)
-  override val valueMutator: NullableBigIntVector#Mutator = valueVector.getMutator
-
-  override def setNull(): Unit = valueMutator.setNull(count)
-  override def setValue(row: InternalRow): Unit
-    = valueMutator.setSafe(count, row.getLong(ordinal))
-}
-
-private[arrow] class FloatColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator)
-  extends PrimitiveColumnWriter(ordinal) {
-  override val valueVector: NullableFloat4Vector
-    = new NullableFloat4Vector("FloatValue", getFieldType(dtype), allocator)
-  override val valueMutator: NullableFloat4Vector#Mutator = valueVector.getMutator
-
-  override def setNull(): Unit = valueMutator.setNull(count)
-  override def setValue(row: InternalRow): Unit
-    = valueMutator.setSafe(count, row.getFloat(ordinal))
-}
-
-private[arrow] class DoubleColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator)
-  extends PrimitiveColumnWriter(ordinal) {
-  override val valueVector: NullableFloat8Vector
-    = new NullableFloat8Vector("DoubleValue", getFieldType(dtype), allocator)
-  override val valueMutator: NullableFloat8Vector#Mutator = valueVector.getMutator
-
-  override def setNull(): Unit = valueMutator.setNull(count)
-  override def setValue(row: InternalRow): Unit
-    = valueMutator.setSafe(count, row.getDouble(ordinal))
-}
-
-private[arrow] class ByteColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator)
-  extends PrimitiveColumnWriter(ordinal) {
-  override val valueVector: NullableUInt1Vector
-    = new NullableUInt1Vector("ByteValue", getFieldType(dtype), allocator)
-  override val valueMutator: NullableUInt1Vector#Mutator = valueVector.getMutator
-
-  override def setNull(): Unit = valueMutator.setNull(count)
-  override def setValue(row: InternalRow): Unit
-    = valueMutator.setSafe(count, row.getByte(ordinal))
-}
-
-private[arrow] class UTF8StringColumnWriter(
-    dtype: ArrowType,
-    ordinal: Int,
-    allocator: BufferAllocator)
-  extends PrimitiveColumnWriter(ordinal) {
-  override val valueVector: NullableVarCharVector
-    = new NullableVarCharVector("UTF8StringValue", getFieldType(dtype), allocator)
-  override val valueMutator: NullableVarCharVector#Mutator = valueVector.getMutator
-
-  override def setNull(): Unit = valueMutator.setNull(count)
-  override def setValue(row: InternalRow): Unit = {
-    val str = row.getUTF8String(ordinal)
-    valueMutator.setSafe(count, str.getByteBuffer, 0, str.numBytes)
-  }
-}
-
-private[arrow] class BinaryColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator)
-  extends PrimitiveColumnWriter(ordinal) {
-  override val valueVector: NullableVarBinaryVector
-    = new NullableVarBinaryVector("BinaryValue", getFieldType(dtype), allocator)
-  override val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator
-
-  override def setNull(): Unit = valueMutator.setNull(count)
-  override def setValue(row: InternalRow): Unit = {
-    val bytes = row.getBinary(ordinal)
-    valueMutator.setSafe(count, bytes, 0, bytes.length)
-  }
-}
-
-private[arrow] class DateColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator)
-  extends PrimitiveColumnWriter(ordinal) {
-  override val valueVector: NullableDateDayVector
-    = new NullableDateDayVector("DateValue", getFieldType(dtype), allocator)
-  override val valueMutator: NullableDateDayVector#Mutator = valueVector.getMutator
-
-  override def setNull(): Unit = valueMutator.setNull(count)
-  override def setValue(row: InternalRow): Unit = {
-    valueMutator.setSafe(count, row.getInt(ordinal))
-  }
-}
-
-private[arrow] class TimeStampColumnWriter(
-    dtype: ArrowType,
-    ordinal: Int,
-    allocator: BufferAllocator)
-  extends PrimitiveColumnWriter(ordinal) {
-  override val valueVector: NullableTimeStampMicroVector
-    = new NullableTimeStampMicroVector("TimeStampValue", getFieldType(dtype), allocator)
-  override val valueMutator: NullableTimeStampMicroVector#Mutator = valueVector.getMutator
-
-  override def setNull(): Unit = valueMutator.setNull(count)
-  override def setValue(row: InternalRow): Unit = {
-    valueMutator.setSafe(count, row.getLong(ordinal))
-  }
-}
-
-private[arrow] object ColumnWriter {
-
-  /**
-   * Create an Arrow ColumnWriter given the type and ordinal of row.
-   */
-  def apply(dataType: DataType, ordinal: Int, allocator: BufferAllocator): ColumnWriter = {
-    val dtype = ArrowUtils.toArrowType(dataType)
-    dataType match {
-      case BooleanType => new BooleanColumnWriter(dtype, ordinal, allocator)
-      case ShortType => new ShortColumnWriter(dtype, ordinal, allocator)
-      case IntegerType => new IntegerColumnWriter(dtype, ordinal, allocator)
-      case LongType => new LongColumnWriter(dtype, ordinal, allocator)
-      case FloatType => new FloatColumnWriter(dtype, ordinal, allocator)
-      case DoubleType => new DoubleColumnWriter(dtype, ordinal, allocator)
-      case ByteType => new ByteColumnWriter(dtype, ordinal, allocator)
-      case StringType => new UTF8StringColumnWriter(dtype, ordinal, allocator)
-      case BinaryType => new BinaryColumnWriter(dtype, ordinal, allocator)
-      case DateType => new DateColumnWriter(dtype, ordinal, allocator)
-      case TimestampType => new TimeStampColumnWriter(dtype, ordinal, allocator)
-      case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType")
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/2ff35a05/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
new file mode 100644
index 0000000..11ba04d
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
@@ -0,0 +1,323 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.arrow
+
+import scala.collection.JavaConverters._
+
+import org.apache.arrow.vector._
+import org.apache.arrow.vector.complex._
+import org.apache.arrow.vector.util.DecimalUtility
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
+import org.apache.spark.sql.types._
+
+object ArrowWriter {
+
+  def create(schema: StructType): ArrowWriter = {
+    val arrowSchema = ArrowUtils.toArrowSchema(schema)
+    val root = VectorSchemaRoot.create(arrowSchema, ArrowUtils.rootAllocator)
+    create(root)
+  }
+
+  def create(root: VectorSchemaRoot): ArrowWriter = {
+    val children = root.getFieldVectors().asScala.map { vector =>
+      vector.allocateNew()
+      createFieldWriter(vector)
+    }
+    new ArrowWriter(root, children.toArray)
+  }
+
+  private def createFieldWriter(vector: ValueVector): ArrowFieldWriter = {
+    val field = vector.getField()
+    (ArrowUtils.fromArrowField(field), vector) match {
+      case (BooleanType, vector: NullableBitVector) => new BooleanWriter(vector)
+      case (ByteType, vector: NullableTinyIntVector) => new ByteWriter(vector)
+      case (ShortType, vector: NullableSmallIntVector) => new ShortWriter(vector)
+      case (IntegerType, vector: NullableIntVector) => new IntegerWriter(vector)
+      case (LongType, vector: NullableBigIntVector) => new LongWriter(vector)
+      case (FloatType, vector: NullableFloat4Vector) => new FloatWriter(vector)
+      case (DoubleType, vector: NullableFloat8Vector) => new DoubleWriter(vector)
+      case (StringType, vector: NullableVarCharVector) => new StringWriter(vector)
+      case (BinaryType, vector: NullableVarBinaryVector) => new BinaryWriter(vector)
+      case (ArrayType(_, _), vector: ListVector) =>
+        val elementVector = createFieldWriter(vector.getDataVector())
+        new ArrayWriter(vector, elementVector)
+      case (StructType(_), vector: NullableMapVector) =>
+        val children = (0 until vector.size()).map { ordinal =>
+          createFieldWriter(vector.getChildByOrdinal(ordinal))
+        }
+        new StructWriter(vector, children.toArray)
+      case (dt, _) =>
+        throw new UnsupportedOperationException(s"Unsupported data type: ${dt.simpleString}")
+    }
+  }
+}
+
+class ArrowWriter(
+    val root: VectorSchemaRoot,
+    fields: Array[ArrowFieldWriter]) {
+
+  def schema: StructType = StructType(fields.map { f =>
+    StructField(f.name, f.dataType, f.nullable)
+  })
+
+  private var count: Int = 0
+
+  def write(row: InternalRow): Unit = {
+    var i = 0
+    while (i < fields.size) {
+      fields(i).write(row, i)
+      i += 1
+    }
+    count += 1
+  }
+
+  def finish(): Unit = {
+    root.setRowCount(count)
+    fields.foreach(_.finish())
+  }
+
+  def reset(): Unit = {
+    root.setRowCount(0)
+    count = 0
+    fields.foreach(_.reset())
+  }
+}
+
+private[arrow] abstract class ArrowFieldWriter {
+
+  def valueVector: ValueVector
+  def valueMutator: ValueVector.Mutator
+
+  def name: String = valueVector.getField().getName()
+  def dataType: DataType = ArrowUtils.fromArrowField(valueVector.getField())
+  def nullable: Boolean = valueVector.getField().isNullable()
+
+  def setNull(): Unit
+  def setValue(input: SpecializedGetters, ordinal: Int): Unit
+
+  private[arrow] var count: Int = 0
+
+  def write(input: SpecializedGetters, ordinal: Int): Unit = {
+    if (input.isNullAt(ordinal)) {
+      setNull()
+    } else {
+      setValue(input, ordinal)
+    }
+    count += 1
+  }
+
+  def finish(): Unit = {
+    valueMutator.setValueCount(count)
+  }
+
+  def reset(): Unit = {
+    valueMutator.reset()
+    count = 0
+  }
+}
+
+private[arrow] class BooleanWriter(val valueVector: NullableBitVector) extends ArrowFieldWriter {
+
+  override def valueMutator: NullableBitVector#Mutator = valueVector.getMutator()
+
+  override def setNull(): Unit = {
+    valueMutator.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    valueMutator.setSafe(count, if (input.getBoolean(ordinal)) 1 else 0)
+  }
+}
+
+private[arrow] class ByteWriter(val valueVector: NullableTinyIntVector) extends ArrowFieldWriter {
+
+  override def valueMutator: NullableTinyIntVector#Mutator = valueVector.getMutator()
+
+  override def setNull(): Unit = {
+    valueMutator.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    valueMutator.setSafe(count, input.getByte(ordinal))
+  }
+}
+
+private[arrow] class ShortWriter(val valueVector: NullableSmallIntVector) extends ArrowFieldWriter {
+
+  override def valueMutator: NullableSmallIntVector#Mutator = valueVector.getMutator()
+
+  override def setNull(): Unit = {
+    valueMutator.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    valueMutator.setSafe(count, input.getShort(ordinal))
+  }
+}
+
+private[arrow] class IntegerWriter(val valueVector: NullableIntVector) extends ArrowFieldWriter {
+
+  override def valueMutator: NullableIntVector#Mutator = valueVector.getMutator()
+
+  override def setNull(): Unit = {
+    valueMutator.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    valueMutator.setSafe(count, input.getInt(ordinal))
+  }
+}
+
+private[arrow] class LongWriter(val valueVector: NullableBigIntVector) extends ArrowFieldWriter {
+
+  override def valueMutator: NullableBigIntVector#Mutator = valueVector.getMutator()
+
+  override def setNull(): Unit = {
+    valueMutator.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    valueMutator.setSafe(count, input.getLong(ordinal))
+  }
+}
+
+private[arrow] class FloatWriter(val valueVector: NullableFloat4Vector) extends ArrowFieldWriter {
+
+  override def valueMutator: NullableFloat4Vector#Mutator = valueVector.getMutator()
+
+  override def setNull(): Unit = {
+    valueMutator.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    valueMutator.setSafe(count, input.getFloat(ordinal))
+  }
+}
+
+private[arrow] class DoubleWriter(val valueVector: NullableFloat8Vector) extends ArrowFieldWriter {
+
+  override def valueMutator: NullableFloat8Vector#Mutator = valueVector.getMutator()
+
+  override def setNull(): Unit = {
+    valueMutator.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    valueMutator.setSafe(count, input.getDouble(ordinal))
+  }
+}
+
+private[arrow] class StringWriter(val valueVector: NullableVarCharVector) extends ArrowFieldWriter {
+
+  override def valueMutator: NullableVarCharVector#Mutator = valueVector.getMutator()
+
+  override def setNull(): Unit = {
+    valueMutator.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    val utf8 = input.getUTF8String(ordinal)
+    // todo: for off-heap UTF8String, how to pass in to arrow without copy?
+    valueMutator.setSafe(count, utf8.getByteBuffer, 0, utf8.numBytes())
+  }
+}
+
+private[arrow] class BinaryWriter(
+    val valueVector: NullableVarBinaryVector) extends ArrowFieldWriter {
+
+  override def valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator()
+
+  override def setNull(): Unit = {
+    valueMutator.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    val bytes = input.getBinary(ordinal)
+    valueMutator.setSafe(count, bytes, 0, bytes.length)
+  }
+}
+
+private[arrow] class ArrayWriter(
+    val valueVector: ListVector,
+    val elementWriter: ArrowFieldWriter) extends ArrowFieldWriter {
+
+  override def valueMutator: ListVector#Mutator = valueVector.getMutator()
+
+  override def setNull(): Unit = {
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    val array = input.getArray(ordinal)
+    var i = 0
+    valueMutator.startNewValue(count)
+    while (i < array.numElements()) {
+      elementWriter.write(array, i)
+      i += 1
+    }
+    valueMutator.endValue(count, array.numElements())
+  }
+
+  override def finish(): Unit = {
+    super.finish()
+    elementWriter.finish()
+  }
+
+  override def reset(): Unit = {
+    super.reset()
+    elementWriter.reset()
+  }
+}
+
+private[arrow] class StructWriter(
+    val valueVector: NullableMapVector,
+    children: Array[ArrowFieldWriter]) extends ArrowFieldWriter {
+
+  override def valueMutator: NullableMapVector#Mutator = valueVector.getMutator()
+
+  override def setNull(): Unit = {
+    var i = 0
+    while (i < children.length) {
+      children(i).setNull()
+      children(i).count += 1
+      i += 1
+    }
+    valueMutator.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    val struct = input.getStruct(ordinal, children.length)
+    var i = 0
+    while (i < struct.numFields) {
+      children(i).write(struct, i)
+      i += 1
+    }
+    valueMutator.setIndexDefined(count)
+  }
+
+  override def finish(): Unit = {
+    super.finish()
+    children.foreach(_.finish())
+  }
+
+  override def reset(): Unit = {
+    super.reset()
+    children.foreach(_.reset())
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/2ff35a05/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
index 55b4655..4893b52 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
@@ -32,7 +32,7 @@ import org.scalatest.BeforeAndAfterAll
 import org.apache.spark.SparkException
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{BinaryType, StructField, StructType}
+import org.apache.spark.sql.types.{BinaryType, IntegerType, StructField, StructType}
 import org.apache.spark.util.Utils
 
 
@@ -857,6 +857,449 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
     collectAndValidate(df, json, "nanData-floating_point.json")
   }
 
+  test("array type conversion") { // four list columns, four rows; Arrow JSON fixture follows
+    val json =
+      s"""
+         |{
+         |  "schema" : {
+         |    "fields" : [ {
+         |      "name" : "a_arr",
+         |      "nullable" : true,
+         |      "type" : {
+         |        "name" : "list"
+         |      },
+         |      "children" : [ {
+         |        "name" : "element",
+         |        "nullable" : false,
+         |        "type" : {
+         |          "name" : "int",
+         |          "bitWidth" : 32,
+         |          "isSigned" : true
+         |        },
+         |        "children" : [ ],
+         |        "typeLayout" : {
+         |          "vectors" : [ {
+         |            "type" : "VALIDITY",
+         |            "typeBitWidth" : 1
+         |          }, {
+         |            "type" : "DATA",
+         |            "typeBitWidth" : 32
+         |          } ]
+         |        }
+         |      } ],
+         |      "typeLayout" : {
+         |        "vectors" : [ {
+         |          "type" : "VALIDITY",
+         |          "typeBitWidth" : 1
+         |        }, {
+         |          "type" : "OFFSET",
+         |          "typeBitWidth" : 32
+         |        } ]
+         |      }
+         |    }, {
+         |      "name" : "b_arr",
+         |      "nullable" : true,
+         |      "type" : {
+         |        "name" : "list"
+         |      },
+         |      "children" : [ {
+         |        "name" : "element",
+         |        "nullable" : false,
+         |        "type" : {
+         |          "name" : "int",
+         |          "bitWidth" : 32,
+         |          "isSigned" : true
+         |        },
+         |        "children" : [ ],
+         |        "typeLayout" : {
+         |          "vectors" : [ {
+         |            "type" : "VALIDITY",
+         |            "typeBitWidth" : 1
+         |          }, {
+         |            "type" : "DATA",
+         |            "typeBitWidth" : 32
+         |          } ]
+         |        }
+         |      } ],
+         |      "typeLayout" : {
+         |        "vectors" : [ {
+         |          "type" : "VALIDITY",
+         |          "typeBitWidth" : 1
+         |        }, {
+         |          "type" : "OFFSET",
+         |          "typeBitWidth" : 32
+         |        } ]
+         |      }
+         |    }, {
+         |      "name" : "c_arr",
+         |      "nullable" : true,
+         |      "type" : {
+         |        "name" : "list"
+         |      },
+         |      "children" : [ {
+         |        "name" : "element",
+         |        "nullable" : true,
+         |        "type" : {
+         |          "name" : "int",
+         |          "bitWidth" : 32,
+         |          "isSigned" : true
+         |        },
+         |        "children" : [ ],
+         |        "typeLayout" : {
+         |          "vectors" : [ {
+         |            "type" : "VALIDITY",
+         |            "typeBitWidth" : 1
+         |          }, {
+         |            "type" : "DATA",
+         |            "typeBitWidth" : 32
+         |          } ]
+         |        }
+         |      } ],
+         |      "typeLayout" : {
+         |        "vectors" : [ {
+         |          "type" : "VALIDITY",
+         |          "typeBitWidth" : 1
+         |        }, {
+         |          "type" : "OFFSET",
+         |          "typeBitWidth" : 32
+         |        } ]
+         |      }
+         |    }, {
+         |      "name" : "d_arr",
+         |      "nullable" : true,
+         |      "type" : {
+         |        "name" : "list"
+         |      },
+         |      "children" : [ {
+         |        "name" : "element",
+         |        "nullable" : true,
+         |        "type" : {
+         |          "name" : "list"
+         |        },
+         |        "children" : [ {
+         |          "name" : "element",
+         |          "nullable" : false,
+         |          "type" : {
+         |            "name" : "int",
+         |            "bitWidth" : 32,
+         |            "isSigned" : true
+         |          },
+         |          "children" : [ ],
+         |          "typeLayout" : {
+         |            "vectors" : [ {
+         |              "type" : "VALIDITY",
+         |              "typeBitWidth" : 1
+         |            }, {
+         |              "type" : "DATA",
+         |              "typeBitWidth" : 32
+         |            } ]
+         |          }
+         |        } ],
+         |        "typeLayout" : {
+         |          "vectors" : [ {
+         |            "type" : "VALIDITY",
+         |            "typeBitWidth" : 1
+         |          }, {
+         |            "type" : "OFFSET",
+         |            "typeBitWidth" : 32
+         |          } ]
+         |        }
+         |      } ],
+         |      "typeLayout" : {
+         |        "vectors" : [ {
+         |          "type" : "VALIDITY",
+         |          "typeBitWidth" : 1
+         |        }, {
+         |          "type" : "OFFSET",
+         |          "typeBitWidth" : 32
+         |        } ]
+         |      }
+         |    } ]
+         |  },
+         |  "batches" : [ {
+         |    "count" : 4,
+         |    "columns" : [ {
+         |      "name" : "a_arr",
+         |      "count" : 4,
+         |      "VALIDITY" : [ 1, 1, 1, 1 ],
+         |      "OFFSET" : [ 0, 2, 4, 4, 5 ],
+         |      "children" : [ {
+         |        "name" : "element",
+         |        "count" : 5,
+         |        "VALIDITY" : [ 1, 1, 1, 1, 1 ],
+         |        "DATA" : [ 1, 2, 3, 4, 5 ]
+         |      } ]
+         |    }, {
+         |      "name" : "b_arr",
+         |      "count" : 4,
+         |      "VALIDITY" : [ 1, 0, 1, 0 ],
+         |      "OFFSET" : [ 0, 2, 2, 2, 2 ],
+         |      "children" : [ {
+         |        "name" : "element",
+         |        "count" : 2,
+         |        "VALIDITY" : [ 1, 1 ],
+         |        "DATA" : [ 1, 2 ]
+         |      } ]
+         |    }, {
+         |      "name" : "c_arr",
+         |      "count" : 4,
+         |      "VALIDITY" : [ 1, 1, 1, 1 ],
+         |      "OFFSET" : [ 0, 2, 4, 4, 5 ],
+         |      "children" : [ {
+         |        "name" : "element",
+         |        "count" : 5,
+         |        "VALIDITY" : [ 1, 1, 1, 0, 1 ],
+         |        "DATA" : [ 1, 2, 3, 0, 5 ]
+         |      } ]
+         |    }, {
+         |      "name" : "d_arr",
+         |      "count" : 4,
+         |      "VALIDITY" : [ 1, 1, 1, 1 ],
+         |      "OFFSET" : [ 0, 1, 3, 3, 4 ],
+         |      "children" : [ {
+         |        "name" : "element",
+         |        "count" : 4,
+         |        "VALIDITY" : [ 1, 1, 1, 1 ],
+         |        "OFFSET" : [ 0, 2, 3, 3, 4 ],
+         |        "children" : [ {
+         |          "name" : "element",
+         |          "count" : 4,
+         |          "VALIDITY" : [ 1, 1, 1, 1 ],
+         |          "DATA" : [ 1, 2, 3, 5 ]
+         |        } ]
+         |      } ]
+         |    } ]
+         |  } ]
+         |}
+       """.stripMargin
+    // Rows: a = plain ints, b = null arrays (None), c = null elements, d = arrays of arrays.
+    val aArr = Seq(Seq(1, 2), Seq(3, 4), Seq(), Seq(5))
+    val bArr = Seq(Some(Seq(1, 2)), None, Some(Seq()), None)
+    val cArr = Seq(Seq(Some(1), Some(2)), Seq(Some(3), None), Seq(), Seq(Some(5)))
+    val dArr = Seq(Seq(Seq(1, 2)), Seq(Seq(3), Seq()), Seq(), Seq(Seq(5)))
+
+    val df = aArr.zip(bArr).zip(cArr).zip(dArr).map {
+      case (((a, b), c), d) => (a, b, c, d)
+    }.toDF("a_arr", "b_arr", "c_arr", "d_arr")
+    // NOTE(review): collectAndValidate lives outside this hunk; presumably checks df against json.
+    collectAndValidate(df, json, "arrayData.json")
+  }
+
+  test("struct type conversion") { // four struct columns, three rows; Arrow JSON fixture follows
+    val json =
+      s"""
+         |{
+         |  "schema" : {
+         |    "fields" : [ {
+         |      "name" : "a_struct",
+         |      "nullable" : false,
+         |      "type" : {
+         |        "name" : "struct"
+         |      },
+         |      "children" : [ {
+         |        "name" : "i",
+         |        "nullable" : false,
+         |        "type" : {
+         |          "name" : "int",
+         |          "bitWidth" : 32,
+         |          "isSigned" : true
+         |        },
+         |        "children" : [ ],
+         |        "typeLayout" : {
+         |          "vectors" : [ {
+         |            "type" : "VALIDITY",
+         |            "typeBitWidth" : 1
+         |          }, {
+         |            "type" : "DATA",
+         |            "typeBitWidth" : 32
+         |          } ]
+         |        }
+         |      } ],
+         |      "typeLayout" : {
+         |        "vectors" : [ {
+         |          "type" : "VALIDITY",
+         |          "typeBitWidth" : 1
+         |        } ]
+         |      }
+         |    }, {
+         |      "name" : "b_struct",
+         |      "nullable" : true,
+         |      "type" : {
+         |        "name" : "struct"
+         |      },
+         |      "children" : [ {
+         |        "name" : "i",
+         |        "nullable" : false,
+         |        "type" : {
+         |          "name" : "int",
+         |          "bitWidth" : 32,
+         |          "isSigned" : true
+         |        },
+         |        "children" : [ ],
+         |        "typeLayout" : {
+         |          "vectors" : [ {
+         |            "type" : "VALIDITY",
+         |            "typeBitWidth" : 1
+         |          }, {
+         |            "type" : "DATA",
+         |            "typeBitWidth" : 32
+         |          } ]
+         |        }
+         |      } ],
+         |      "typeLayout" : {
+         |        "vectors" : [ {
+         |          "type" : "VALIDITY",
+         |          "typeBitWidth" : 1
+         |        } ]
+         |      }
+         |    }, {
+         |      "name" : "c_struct",
+         |      "nullable" : false,
+         |      "type" : {
+         |        "name" : "struct"
+         |      },
+         |      "children" : [ {
+         |        "name" : "i",
+         |        "nullable" : true,
+         |        "type" : {
+         |          "name" : "int",
+         |          "bitWidth" : 32,
+         |          "isSigned" : true
+         |        },
+         |        "children" : [ ],
+         |        "typeLayout" : {
+         |          "vectors" : [ {
+         |            "type" : "VALIDITY",
+         |            "typeBitWidth" : 1
+         |          }, {
+         |            "type" : "DATA",
+         |            "typeBitWidth" : 32
+         |          } ]
+         |        }
+         |      } ],
+         |      "typeLayout" : {
+         |        "vectors" : [ {
+         |          "type" : "VALIDITY",
+         |          "typeBitWidth" : 1
+         |        } ]
+         |      }
+         |    }, {
+         |      "name" : "d_struct",
+         |      "nullable" : true,
+         |      "type" : {
+         |        "name" : "struct"
+         |      },
+         |      "children" : [ {
+         |        "name" : "nested",
+         |        "nullable" : true,
+         |        "type" : {
+         |          "name" : "struct"
+         |        },
+         |        "children" : [ {
+         |          "name" : "i",
+         |          "nullable" : true,
+         |          "type" : {
+         |            "name" : "int",
+         |            "bitWidth" : 32,
+         |            "isSigned" : true
+         |          },
+         |          "children" : [ ],
+         |          "typeLayout" : {
+         |            "vectors" : [ {
+         |              "type" : "VALIDITY",
+         |              "typeBitWidth" : 1
+         |            }, {
+         |              "type" : "DATA",
+         |              "typeBitWidth" : 32
+         |            } ]
+         |          }
+         |        } ],
+         |        "typeLayout" : {
+         |          "vectors" : [ {
+         |            "type" : "VALIDITY",
+         |            "typeBitWidth" : 1
+         |          } ]
+         |        }
+         |      } ],
+         |      "typeLayout" : {
+         |        "vectors" : [ {
+         |          "type" : "VALIDITY",
+         |          "typeBitWidth" : 1
+         |        } ]
+         |      }
+         |    } ]
+         |  },
+         |  "batches" : [ {
+         |    "count" : 3,
+         |    "columns" : [ {
+         |      "name" : "a_struct",
+         |      "count" : 3,
+         |      "VALIDITY" : [ 1, 1, 1 ],
+         |      "children" : [ {
+         |        "name" : "i",
+         |        "count" : 3,
+         |        "VALIDITY" : [ 1, 1, 1 ],
+         |        "DATA" : [ 1, 2, 3 ]
+         |      } ]
+         |    }, {
+         |      "name" : "b_struct",
+         |      "count" : 3,
+         |      "VALIDITY" : [ 1, 0, 1 ],
+         |      "children" : [ {
+         |        "name" : "i",
+         |        "count" : 3,
+         |        "VALIDITY" : [ 1, 0, 1 ],
+         |        "DATA" : [ 1, 2, 3 ]
+         |      } ]
+         |    }, {
+         |      "name" : "c_struct",
+         |      "count" : 3,
+         |      "VALIDITY" : [ 1, 1, 1 ],
+         |      "children" : [ {
+         |        "name" : "i",
+         |        "count" : 3,
+         |        "VALIDITY" : [ 1, 0, 1 ],
+         |        "DATA" : [ 1, 2, 3 ]
+         |      } ]
+         |    }, {
+         |      "name" : "d_struct",
+         |      "count" : 3,
+         |      "VALIDITY" : [ 1, 0, 1 ],
+         |      "children" : [ {
+         |        "name" : "nested",
+         |        "count" : 3,
+         |        "VALIDITY" : [ 1, 0, 0 ],
+         |        "children" : [ {
+         |          "name" : "i",
+         |          "count" : 3,
+         |          "VALIDITY" : [ 1, 0, 0 ],
+         |          "DATA" : [ 1, 2, 0 ]
+         |        } ]
+         |      } ]
+         |    } ]
+         |  } ]
+         |}
+       """.stripMargin
+    // Rows: a = all present, b = null struct, c = null field, d = nested with null outer and inner.
+    val aStruct = Seq(Row(1), Row(2), Row(3))
+    val bStruct = Seq(Row(1), null, Row(3))
+    val cStruct = Seq(Row(1), Row(null), Row(3))
+    val dStruct = Seq(Row(Row(1)), null, Row(null))
+    val data = aStruct.zip(bStruct).zip(cStruct).zip(dStruct).map {
+      case (((a, b), c), d) => Row(a, b, c, d)
+    }
+    // Explicit schema; a/b/c nullability below mirrors the "nullable" flags in the JSON above.
+    val rdd = sparkContext.parallelize(data)
+    val schema = new StructType()
+      .add("a_struct", new StructType().add("i", IntegerType, nullable = false), nullable = false)
+      .add("b_struct", new StructType().add("i", IntegerType, nullable = false), nullable = true)
+      .add("c_struct", new StructType().add("i", IntegerType, nullable = true), nullable = false)
+      .add("d_struct", new StructType().add("nested", new StructType().add("i", IntegerType)))
+    val df = spark.createDataFrame(rdd, schema)
+    // NOTE(review): collectAndValidate lives outside this hunk; presumably checks df against json.
+    collectAndValidate(df, json, "structData.json")
+  }
+
   test("partitioned DataFrame") {
     val json1 =
       s"""
@@ -1015,6 +1458,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
     spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", maxRecordsPerBatch)
     val df = spark.sparkContext.parallelize(1 to totalRecords, 2).toDF("i")
     val arrowPayloads = df.toArrowPayload.collect()
+    assert(arrowPayloads.length >= 4)
     val allocator = new RootAllocator(Long.MaxValue)
     val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator))
     var recordCount = 0
@@ -1039,7 +1483,6 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
     }
 
     runUnsupported { decimalData.toArrowPayload.collect() }
-    runUnsupported { arrayData.toDF().toArrowPayload.collect() }
     runUnsupported { mapData.toDF().toArrowPayload.collect() }
     runUnsupported { complexData.toArrowPayload.collect() }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/2ff35a05/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
new file mode 100644
index 0000000..e9a6293
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
@@ -0,0 +1,260 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.arrow
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.ArrayData
+import org.apache.spark.sql.execution.vectorized.ArrowColumnVector
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * Tests for `ArrowWriter`: each test creates a writer from a Spark schema, writes
+ * `InternalRow`s through it, and reads the first field vector back via `ArrowColumnVector`.
+ *
+ * Every test closes `writer.root` in a `finally` block so the root is released even when
+ * an assertion or a write fails part-way through.
+ */
+class ArrowWriterSuite extends SparkFunSuite {
+
+  test("simple") {
+    // Writes nullable values of type `dt` one row at a time and reads each one back;
+    // a `null` in `data` is expected to come back as a null slot.
+    def check(dt: DataType, data: Seq[Any]): Unit = {
+      val schema = new StructType().add("value", dt, nullable = true)
+      val writer = ArrowWriter.create(schema)
+      try {
+        assert(writer.schema === schema)
+
+        data.foreach { datum =>
+          writer.write(InternalRow(datum))
+        }
+        writer.finish()
+
+        val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0))
+        data.zipWithIndex.foreach {
+          case (null, rowId) => assert(reader.isNullAt(rowId))
+          case (datum, rowId) =>
+            val value = dt match {
+              case BooleanType => reader.getBoolean(rowId)
+              case ByteType => reader.getByte(rowId)
+              case ShortType => reader.getShort(rowId)
+              case IntegerType => reader.getInt(rowId)
+              case LongType => reader.getLong(rowId)
+              case FloatType => reader.getFloat(rowId)
+              case DoubleType => reader.getDouble(rowId)
+              case StringType => reader.getUTF8String(rowId)
+              case BinaryType => reader.getBinary(rowId)
+            }
+            assert(value === datum)
+        }
+      } finally {
+        // Runs on failure too, so a failed assertion does not leave the root open.
+        writer.root.close()
+      }
+    }
+    check(BooleanType, Seq(true, null, false))
+    check(ByteType, Seq(1.toByte, 2.toByte, null, 4.toByte))
+    check(ShortType, Seq(1.toShort, 2.toShort, null, 4.toShort))
+    check(IntegerType, Seq(1, 2, null, 4))
+    check(LongType, Seq(1L, 2L, null, 4L))
+    check(FloatType, Seq(1.0f, 2.0f, null, 4.0f))
+    check(DoubleType, Seq(1.0d, 2.0d, null, 4.0d))
+    check(StringType, Seq("a", "b", null, "d").map(UTF8String.fromString))
+    check(BinaryType, Seq("a".getBytes(), "b".getBytes(), null, "d".getBytes()))
+  }
+
+  test("get multiple") {
+    // Writes non-nullable primitives of type `dt` and reads them all back with a single
+    // bulk getter call covering rows [0, data.size).
+    def check(dt: DataType, data: Seq[Any]): Unit = {
+      val schema = new StructType().add("value", dt, nullable = false)
+      val writer = ArrowWriter.create(schema)
+      try {
+        assert(writer.schema === schema)
+
+        data.foreach { datum =>
+          writer.write(InternalRow(datum))
+        }
+        writer.finish()
+
+        val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0))
+        val values = dt match {
+          case BooleanType => reader.getBooleans(0, data.size)
+          case ByteType => reader.getBytes(0, data.size)
+          case ShortType => reader.getShorts(0, data.size)
+          case IntegerType => reader.getInts(0, data.size)
+          case LongType => reader.getLongs(0, data.size)
+          case FloatType => reader.getFloats(0, data.size)
+          case DoubleType => reader.getDoubles(0, data.size)
+        }
+        assert(values === data)
+      } finally {
+        // Runs on failure too, so a failed assertion does not leave the root open.
+        writer.root.close()
+      }
+    }
+    check(BooleanType, Seq(true, false))
+    check(ByteType, (0 until 10).map(_.toByte))
+    check(ShortType, (0 until 10).map(_.toShort))
+    check(IntegerType, (0 until 10))
+    check(LongType, (0 until 10).map(_.toLong))
+    check(FloatType, (0 until 10).map(_.toFloat))
+    check(DoubleType, (0 until 10).map(_.toDouble))
+  }
+
+  test("array") {
+    // Five rows: two populated arrays, a null array, an empty array, and an array that
+    // contains a null element.
+    val schema = new StructType()
+      .add("arr", ArrayType(IntegerType, containsNull = true), nullable = true)
+    val writer = ArrowWriter.create(schema)
+    try {
+      assert(writer.schema === schema)
+
+      writer.write(InternalRow(ArrayData.toArrayData(Array(1, 2, 3))))
+      writer.write(InternalRow(ArrayData.toArrayData(Array(4, 5))))
+      writer.write(InternalRow(null))
+      writer.write(InternalRow(ArrayData.toArrayData(Array.empty[Int])))
+      writer.write(InternalRow(ArrayData.toArrayData(Array(6, null, 8))))
+      writer.finish()
+
+      val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0))
+
+      val array0 = reader.getArray(0)
+      assert(array0.numElements() === 3)
+      assert(array0.getInt(0) === 1)
+      assert(array0.getInt(1) === 2)
+      assert(array0.getInt(2) === 3)
+
+      val array1 = reader.getArray(1)
+      assert(array1.numElements() === 2)
+      assert(array1.getInt(0) === 4)
+      assert(array1.getInt(1) === 5)
+
+      assert(reader.isNullAt(2))
+
+      val array3 = reader.getArray(3)
+      assert(array3.numElements() === 0)
+
+      val array4 = reader.getArray(4)
+      assert(array4.numElements() === 3)
+      assert(array4.getInt(0) === 6)
+      assert(array4.isNullAt(1))
+      assert(array4.getInt(2) === 8)
+    } finally {
+      // Runs on failure too, so a failed assertion does not leave the root open.
+      writer.root.close()
+    }
+  }
+
+  test("nested array") {
+    // Three rows: an array of five inner arrays (populated, populated, null, empty,
+    // containing a null), a null outer array, and an empty outer array.
+    val schema = new StructType().add("nested", ArrayType(ArrayType(IntegerType)))
+    val writer = ArrowWriter.create(schema)
+    try {
+      assert(writer.schema === schema)
+
+      writer.write(InternalRow(ArrayData.toArrayData(Array(
+        ArrayData.toArrayData(Array(1, 2, 3)),
+        ArrayData.toArrayData(Array(4, 5)),
+        null,
+        ArrayData.toArrayData(Array.empty[Int]),
+        ArrayData.toArrayData(Array(6, null, 8))))))
+      writer.write(InternalRow(null))
+      writer.write(InternalRow(ArrayData.toArrayData(Array.empty)))
+      writer.finish()
+
+      val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0))
+
+      val array0 = reader.getArray(0)
+      assert(array0.numElements() === 5)
+
+      val array00 = array0.getArray(0)
+      assert(array00.numElements() === 3)
+      assert(array00.getInt(0) === 1)
+      assert(array00.getInt(1) === 2)
+      assert(array00.getInt(2) === 3)
+
+      val array01 = array0.getArray(1)
+      assert(array01.numElements() === 2)
+      assert(array01.getInt(0) === 4)
+      assert(array01.getInt(1) === 5)
+
+      assert(array0.isNullAt(2))
+
+      val array03 = array0.getArray(3)
+      assert(array03.numElements() === 0)
+
+      val array04 = array0.getArray(4)
+      assert(array04.numElements() === 3)
+      assert(array04.getInt(0) === 6)
+      assert(array04.isNullAt(1))
+      assert(array04.getInt(2) === 8)
+
+      assert(reader.isNullAt(1))
+
+      val array2 = reader.getArray(2)
+      assert(array2.numElements() === 0)
+    } finally {
+      // Runs on failure too, so a failed assertion does not leave the root open.
+      writer.root.close()
+    }
+  }
+
+  test("struct") {
+    // Five rows covering: both fields set, both fields null, the struct itself null,
+    // and each field null on its own.
+    val schema = new StructType()
+      .add("struct", new StructType().add("i", IntegerType).add("str", StringType))
+    val writer = ArrowWriter.create(schema)
+    try {
+      assert(writer.schema === schema)
+
+      writer.write(InternalRow(InternalRow(1, UTF8String.fromString("str1"))))
+      writer.write(InternalRow(InternalRow(null, null)))
+      writer.write(InternalRow(null))
+      writer.write(InternalRow(InternalRow(4, null)))
+      writer.write(InternalRow(InternalRow(null, UTF8String.fromString("str5"))))
+      writer.finish()
+
+      val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0))
+
+      val struct0 = reader.getStruct(0, 2)
+      assert(struct0.getInt(0) === 1)
+      assert(struct0.getUTF8String(1) === UTF8String.fromString("str1"))
+
+      val struct1 = reader.getStruct(1, 2)
+      assert(struct1.isNullAt(0))
+      assert(struct1.isNullAt(1))
+
+      assert(reader.isNullAt(2))
+
+      val struct3 = reader.getStruct(3, 2)
+      assert(struct3.getInt(0) === 4)
+      assert(struct3.isNullAt(1))
+
+      val struct4 = reader.getStruct(4, 2)
+      assert(struct4.isNullAt(0))
+      assert(struct4.getUTF8String(1) === UTF8String.fromString("str5"))
+    } finally {
+      // Runs on failure too, so a failed assertion does not leave the root open.
+      writer.root.close()
+    }
+  }
+
+  test("nested struct") {
+    // Four rows: fully populated, inner fields null, inner struct null, outer struct null.
+    val schema = new StructType().add("struct",
+      new StructType().add("nested", new StructType().add("i", IntegerType).add("str", StringType)))
+    val writer = ArrowWriter.create(schema)
+    try {
+      assert(writer.schema === schema)
+
+      writer.write(InternalRow(InternalRow(InternalRow(1, UTF8String.fromString("str1")))))
+      writer.write(InternalRow(InternalRow(InternalRow(null, null))))
+      writer.write(InternalRow(InternalRow(null)))
+      writer.write(InternalRow(null))
+      writer.finish()
+
+      val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0))
+
+      val struct00 = reader.getStruct(0, 1).getStruct(0, 2)
+      assert(struct00.getInt(0) === 1)
+      assert(struct00.getUTF8String(1) === UTF8String.fromString("str1"))
+
+      val struct10 = reader.getStruct(1, 1).getStruct(0, 2)
+      assert(struct10.isNullAt(0))
+      assert(struct10.isNullAt(1))
+
+      val struct2 = reader.getStruct(2, 1)
+      assert(struct2.isNullAt(0))
+
+      assert(reader.isNullAt(3))
+    } finally {
+      // Runs on failure too, so a failed assertion does not leave the root open.
+      writer.root.close()
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message