spark-commits mailing list archives

From: dav...@apache.org
Subject: spark git commit: [SPARK-13255] [SQL] Update vectorized reader to directly return ColumnarBatch instead of InternalRows.
Date: Fri, 04 Mar 2016 23:15:56 GMT
Repository: spark
Updated Branches:
  refs/heads/master 5f42c28b1 -> a6e2bd31f


[SPARK-13255] [SQL] Update vectorized reader to directly return ColumnarBatch instead of InternalRows.

## What changes were proposed in this pull request?

Currently, the Parquet reader returns rows one by one, which is bad for performance. This patch
updates the reader to return ColumnarBatches directly. This is only enabled with whole stage
codegen, which is currently the only operator able to consume ColumnarBatches (instead of rows).
The current implementation is a bit of a hack to get this to work, and we should do more
refactoring of these low-level interfaces to make this work better.
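
For illustration, a minimal Scala sketch of how a caller can drive the reader in each mode. It uses only the public methods touched by this patch (`initialize`, `resultBatch`, `enableReturningBatches`, `nextKeyValue`, `getCurrentValue`, `close`); the input path is a placeholder and the sketch itself is not part of the patch.

```
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.parquet.UnsafeRowParquetRecordReader
import org.apache.spark.sql.execution.vectorized.ColumnarBatch

val path = "/tmp/example.parquet"  // placeholder input file

// Row-at-a-time (the existing path): getCurrentValue yields one InternalRow per call.
val rowReader = new UnsafeRowParquetRecordReader
rowReader.initialize(path, null)
while (rowReader.nextKeyValue()) {
  val row = rowReader.getCurrentValue.asInstanceOf[InternalRow]
  // ... consume one row ...
}
rowReader.close()

// Batch-at-a-time (new with this patch, vectorized decode only):
// getCurrentValue now yields the reusable ColumnarBatch.
val batchReader = new UnsafeRowParquetRecordReader
batchReader.initialize(path, null)
batchReader.resultBatch()             // allocate the columnar batch
batchReader.enableReturningBatches()  // must be called before any rows are returned
while (batchReader.nextKeyValue()) {
  val batch = batchReader.getCurrentValue.asInstanceOf[ColumnarBatch]
  // ... consume batch.numRows() rows at once ...
}
batchReader.close()
```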

## How was this patch tested?

```
Results:
TPCDS:                             Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)
---------------------------------------------------------------------------------
q55 (before)                             8897 / 9265         12.9          77.2
q55 (after)                              5486 / 5753         21.0          47.6
```

Author: Nong Li <nong@databricks.com>

Closes #11435 from nongli/spark-13255.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a6e2bd31
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a6e2bd31
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a6e2bd31

Branch: refs/heads/master
Commit: a6e2bd31f52f9e9452e52ab5b846de3dee8b98a7
Parents: 5f42c28
Author: Nong Li <nong@databricks.com>
Authored: Fri Mar 4 15:15:48 2016 -0800
Committer: Davies Liu <davies.liu@gmail.com>
Committed: Fri Mar 4 15:15:48 2016 -0800

----------------------------------------------------------------------
 .../parquet/UnsafeRowParquetRecordReader.java   | 29 ++++++--
 .../execution/vectorized/ColumnVectorUtils.java | 57 +++++++++++++++
 .../sql/execution/vectorized/ColumnarBatch.java | 12 ++++
 .../vectorized/OnHeapColumnVector.java          |  3 -
 .../spark/sql/execution/ExistingRDD.scala       | 67 ++++++++++++++++--
 .../datasources/DataSourceStrategy.scala        | 72 +++++++++++++++++--
 .../execution/datasources/SqlNewHadoopRDD.scala |  8 ++-
 .../datasources/parquet/ParquetIOSuite.scala    |  8 +--
 .../parquet/ParquetReadBenchmark.scala          | 73 +++++++++++++++-----
 9 files changed, 284 insertions(+), 45 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a6e2bd31/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
index 57dbd7c..7d768b1 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
@@ -37,7 +37,6 @@ import org.apache.parquet.schema.PrimitiveType;
 import org.apache.parquet.schema.Type;
 
 import org.apache.spark.memory.MemoryMode;
-import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
 import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder;
 import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter;
@@ -57,10 +56,14 @@ import static org.apache.parquet.column.ValuesType.*;
  *
  * TODO: handle complex types, decimal requiring more than 8 bytes, INT96. Schema mismatch.
  * All of these can be handled efficiently and easily with codegen.
+ *
+ * This class can either return InternalRows or ColumnarBatches. With whole stage codegen
+ * enabled, this class returns ColumnarBatches which offers significant performance gains.
+ * TODO: make this always return ColumnarBatches.
  */
-public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBase<InternalRow> {
+public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBase<Object> {
   /**
-   * Batch of unsafe rows that we assemble and the current index we've returned. Everytime this
+   * Batch of unsafe rows that we assemble and the current index we've returned. Every time this
    * batch is used up (batchIdx == numBatched), we populated the batch.
    */
   private UnsafeRow[] rows = new UnsafeRow[64];
@@ -115,12 +118,16 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
    * code between the path that uses the MR decoders and the vectorized ones.
    *
    * TODOs:
-   *  - Implement all the encodings to support vectorized.
    *  - Implement v2 page formats (just make sure we create the correct decoders).
    */
   private ColumnarBatch columnarBatch;
 
   /**
+   * If true, this class returns batches instead of rows.
+   */
+  private boolean returnColumnarBatch;
+
+  /**
    * The default config on whether columnarBatch should be offheap.
    */
   private static final MemoryMode DEFAULT_MEMORY_MODE = MemoryMode.ON_HEAP;
@@ -169,6 +176,8 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
 
   @Override
   public boolean nextKeyValue() throws IOException, InterruptedException {
+    if (returnColumnarBatch) return nextBatch();
+
     if (batchIdx >= numBatched) {
       if (vectorizedDecode()) {
         if (!nextBatch()) return false;
@@ -181,7 +190,9 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
   }
 
   @Override
-  public InternalRow getCurrentValue() throws IOException, InterruptedException {
+  public Object getCurrentValue() throws IOException, InterruptedException {
+    if (returnColumnarBatch) return columnarBatch;
+
     if (vectorizedDecode()) {
       return columnarBatch.getRow(batchIdx - 1);
     } else {
@@ -211,6 +222,14 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
   }
 
   /**
+   * Can be called before any rows are returned to enable returning columnar batches directly.
+   */
+  public void enableReturningBatches() {
+    assert(vectorizedDecode());
+    returnColumnarBatch = true;
+  }
+
+  /**
    * Advances to the next batch of rows. Returns false if there are no more.
    */
   public boolean nextBatch() throws IOException {

http://git-wip-us.apache.org/repos/asf/spark/blob/a6e2bd31/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
index 681ace3..68f146f 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
@@ -26,9 +26,11 @@ import org.apache.commons.lang.NotImplementedException;
 
 import org.apache.spark.memory.MemoryMode;
 import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.util.DateTimeUtils;
 import org.apache.spark.sql.types.*;
 import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.UTF8String;
 
 /**
  * Utilities to help manipulate data associate with ColumnVectors. These should be used mostly
@@ -37,6 +39,61 @@ import org.apache.spark.unsafe.types.CalendarInterval;
  */
 public class ColumnVectorUtils {
   /**
+   * Populates the entire `col` with `row[fieldIdx]`
+   */
+  public static void populate(ColumnVector col, InternalRow row, int fieldIdx) {
+    int capacity = col.capacity;
+    DataType t = col.dataType();
+
+    if (row.isNullAt(fieldIdx)) {
+      col.putNulls(0, capacity);
+    } else {
+      if (t == DataTypes.BooleanType) {
+        col.putBooleans(0, capacity, row.getBoolean(fieldIdx));
+      } else if (t == DataTypes.ByteType) {
+        col.putBytes(0, capacity, row.getByte(fieldIdx));
+      } else if (t == DataTypes.ShortType) {
+        col.putShorts(0, capacity, row.getShort(fieldIdx));
+      } else if (t == DataTypes.IntegerType) {
+        col.putInts(0, capacity, row.getInt(fieldIdx));
+      } else if (t == DataTypes.LongType) {
+        col.putLongs(0, capacity, row.getLong(fieldIdx));
+      } else if (t == DataTypes.FloatType) {
+        col.putFloats(0, capacity, row.getFloat(fieldIdx));
+      } else if (t == DataTypes.DoubleType) {
+        col.putDoubles(0, capacity, row.getDouble(fieldIdx));
+      } else if (t == DataTypes.StringType) {
+        UTF8String v = row.getUTF8String(fieldIdx);
+        byte[] bytes = v.getBytes();
+        for (int i = 0; i < capacity; i++) {
+          col.putByteArray(i, bytes);
+        }
+      } else if (t instanceof DecimalType) {
+        DecimalType dt = (DecimalType)t;
+        Decimal d = row.getDecimal(fieldIdx, dt.precision(), dt.scale());
+        if (dt.precision() <= Decimal.MAX_INT_DIGITS()) {
+          col.putInts(0, capacity, (int)d.toUnscaledLong());
+        } else if (dt.precision() <= Decimal.MAX_LONG_DIGITS()) {
+          col.putLongs(0, capacity, d.toUnscaledLong());
+        } else {
+          final BigInteger integer = d.toJavaBigDecimal().unscaledValue();
+          byte[] bytes = integer.toByteArray();
+          for (int i = 0; i < capacity; i++) {
+            col.putByteArray(i, bytes, 0, bytes.length);
+          }
+        }
+      } else if (t instanceof CalendarIntervalType) {
+        CalendarInterval c = (CalendarInterval)row.get(fieldIdx, t);
+        col.getChildColumn(0).putInts(0, capacity, c.months);
+        col.getChildColumn(1).putLongs(0, capacity, c.microseconds);
+      } else if (t instanceof DateType) {
+        Date date = (Date)row.get(fieldIdx, t);
+        col.putInts(0, capacity, DateTimeUtils.fromJavaDate(date));
+      }
+    }
+  }
+
+  /**
    * Returns the array data as the java primitive array.
    * For example, an array of IntegerType will return an int[].
    * Throws exceptions for unhandled schemas.

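As a hedged usage sketch of the new `populate` helper (this mirrors how DataSourceStrategy below fills partition columns once per batch rather than once per row; the schema and value here are invented, and `InternalRow(...)` is assumed to build a generic row):

```
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.vectorized.{ColumnVectorUtils, ColumnarBatch}
import org.apache.spark.sql.types._

// Hypothetical output schema: one data column plus one partition column.
val schema = new StructType().add("id", IntegerType).add("p", IntegerType)
val batch = ColumnarBatch.allocate(schema)

// Partition value for this physical partition, e.g. p = 1.
val partitionValues: InternalRow = InternalRow(1)

// Fill the entire "p" vector (ordinal 1) from field 0 of the partition row.
ColumnVectorUtils.populate(batch.column(1), partitionValues, 0)
```
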
http://git-wip-us.apache.org/repos/asf/spark/blob/a6e2bd31/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
index 2a78058..1876367 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
@@ -22,6 +22,7 @@ import java.util.Iterator;
 import org.apache.commons.lang.NotImplementedException;
 
 import org.apache.spark.memory.MemoryMode;
+import org.apache.spark.sql.Column;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.expressions.GenericMutableRow;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
@@ -316,6 +317,17 @@ public final class ColumnarBatch {
   public ColumnVector column(int ordinal) { return columns[ordinal]; }
 
   /**
+   * Sets (replaces) the column at `ordinal` with column. This can be used to do very efficient
+   * projections.
+   */
+  public void setColumn(int ordinal, ColumnVector column) {
+    if (column instanceof OffHeapColumnVector) {
+      throw new NotImplementedException("Need to ref count columns.");
+    }
+    columns[ordinal] = column;
+  }
+
+  /**
    * Returns the row in this batch at `rowId`. Returned row is reused across calls.
    */
   public ColumnarBatch.Row getRow(int rowId) {

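A brief, hedged sketch of why `setColumn` makes projections cheap, mirroring what `projectedColumnBatch` in DataSourceStrategy below does; the column names are invented:

```
import org.apache.spark.sql.execution.vectorized.ColumnarBatch
import org.apache.spark.sql.types._

// `input` is a batch decoded from the data source with a single column (id);
// the required output schema also carries a partition column (p).
def project(input: ColumnarBatch): ColumnarBatch = {
  val outSchema = new StructType().add("id", IntegerType).add("p", IntegerType)
  val out = ColumnarBatch.allocate(outSchema)
  out.setColumn(0, input.column(0))   // reference the already-decoded vector, no copy
  // out.column(1) would be filled once from the partition values
  // (see the ColumnVectorUtils.populate sketch above); off-heap vectors
  // are rejected by setColumn until reference counting exists.
  out.setNumRows(input.numRows())
  out
}
```
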
http://git-wip-us.apache.org/repos/asf/spark/blob/a6e2bd31/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
index 305e84a..03160d1 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
@@ -62,9 +62,6 @@ public final class OnHeapColumnVector extends ColumnVector {
 
   @Override
   public final void close() {
-    nulls = null;
-    intData = null;
-    doubleData = null;
   }
 
   //

http://git-wip-us.apache.org/repos/asf/spark/blob/a6e2bd31/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 2cbe3f2..36e656b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -139,9 +139,14 @@ private[sql] case class PhysicalRDD(
   // Support codegen so that we can avoid the UnsafeRow conversion in all cases. Codegen
   // never requires UnsafeRow as input.
   override protected def doProduce(ctx: CodegenContext): String = {
+    val columnarBatchClz = "org.apache.spark.sql.execution.vectorized.ColumnarBatch"
     val input = ctx.freshName("input")
+    val idx = ctx.freshName("batchIdx")
+    val batch = ctx.freshName("batch")
     // PhysicalRDD always just has one input
     ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
+    ctx.addMutableState(columnarBatchClz, batch, s"$batch = null;")
+    ctx.addMutableState("int", idx, s"$idx = 0;")
 
     val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true))
     val row = ctx.freshName("row")
@@ -149,14 +154,62 @@ private[sql] case class PhysicalRDD(
     ctx.INPUT_ROW = row
     ctx.currentVars = null
     val columns = exprs.map(_.gen(ctx))
+
+    // The input RDD can either return (all) ColumnarBatches or InternalRows. We determine this
+    // by looking at the first value of the RDD and then calling the function which will process
+    // the remaining. It is faster to return batches.
+    // TODO: The abstractions between this class and SqlNewHadoopRDD makes it difficult to know
+    // here which path to use. Fix this.
+
+
+    val scanBatches = ctx.freshName("processBatches")
+    ctx.addNewFunction(scanBatches,
+      s"""
+      | private void $scanBatches() throws java.io.IOException {
+      |  while (true) {
+      |     int numRows = $batch.numRows();
+      |     if ($idx == 0) $numOutputRows.add(numRows);
+      |
+      |     while ($idx < numRows) {
+      |       InternalRow $row = $batch.getRow($idx++);
+      |       ${columns.map(_.code).mkString("\n").trim}
+      |       ${consume(ctx, columns).trim}
+      |       if (shouldStop()) return;
+      |     }
+      |
+      |     if (!$input.hasNext()) {
+      |       $batch = null;
+      |       break;
+      |     }
+      |     $batch = ($columnarBatchClz)$input.next();
+      |     $idx = 0;
+      |   }
+      | }""".stripMargin)
+
+    val scanRows = ctx.freshName("processRows")
+    ctx.addNewFunction(scanRows,
+      s"""
+       | private void $scanRows(InternalRow $row) throws java.io.IOException {
+       |   while (true) {
+       |     $numOutputRows.add(1);
+       |     ${columns.map(_.code).mkString("\n").trim}
+       |     ${consume(ctx, columns).trim}
+       |     if (shouldStop()) return;
+       |     if (!$input.hasNext()) break;
+       |     $row = (InternalRow)$input.next();
+       |   }
+       | }""".stripMargin)
+
     s"""
-       | while ($input.hasNext()) {
-       |   InternalRow $row = (InternalRow) $input.next();
-       |   $numOutputRows.add(1);
-       |   ${columns.map(_.code).mkString("\n").trim}
-       |   ${consume(ctx, columns).trim}
-       |   if (shouldStop()) {
-       |     return;
+       | if ($batch != null) {
+       |   $scanBatches();
+       | } else if ($input.hasNext()) {
+       |   Object value = $input.next();
+       |   if (value instanceof $columnarBatchClz) {
+       |     $batch = ($columnarBatchClz)value;
+       |     $scanBatches();
+       |   } else {
+       |     $scanRows((InternalRow)value);
        |   }
        | }
      """.stripMargin

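For readability, a plain-Scala sketch of the control flow that the generated template above implements; this is not the generated code itself, and `consume` stands in for the downstream operators:

```
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.vectorized.ColumnarBatch

def scan(input: Iterator[AnyRef])(consume: InternalRow => Unit): Unit = {
  if (!input.hasNext) return
  input.next() match {
    case first: ColumnarBatch =>
      // Batch path: iterate the rows of each batch, then pull the next batch.
      var batch = first
      while (batch != null) {
        var i = 0
        while (i < batch.numRows()) { consume(batch.getRow(i)); i += 1 }
        batch = if (input.hasNext) input.next().asInstanceOf[ColumnarBatch] else null
      }
    case first: InternalRow =>
      // Row path: every remaining element is an InternalRow.
      consume(first)
      while (input.hasNext) consume(input.next().asInstanceOf[InternalRow])
  }
}
```
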
http://git-wip-us.apache.org/repos/asf/spark/blob/a6e2bd31/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index ceb3510..69a6d23 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -23,8 +23,9 @@ import org.apache.spark.{Logging, TaskContext}
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD}
 import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
+import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical
@@ -33,8 +34,9 @@ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.execution.PhysicalRDD.{INPUT_PATHS, PUSHED_FILTERS}
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.command.ExecutedCommand
+import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnVectorUtils}
 import org.apache.spark.sql.sources._
-import org.apache.spark.sql.types.{StringType, StructType}
+import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.{SerializableConfiguration, Utils}
 import org.apache.spark.util.collection.BitSet
@@ -220,6 +222,44 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
     sparkPlan
   }
 
+  /**
+   * Creates a ColumnarBatch that contains the values for `requiredColumns`. These columns can
+   * either come from `input` (columns scanned from the data source) or from the partitioning
+   * values (data from `partitionValues`). This is done *once* per physical partition. When
+   * the column is from `input`, it just references the same underlying column. When using
+   * partition columns, the column is populated once.
+   * TODO: there's probably a cleaner way to do this.
+   */
+  private def projectedColumnBatch(
+      input: ColumnarBatch,
+      requiredColumns: Seq[Attribute],
+      dataColumns: Seq[Attribute],
+      partitionColumnSchema: StructType,
+      partitionValues: InternalRow) : ColumnarBatch = {
+    val result = ColumnarBatch.allocate(StructType.fromAttributes(requiredColumns))
+    var resultIdx = 0
+    var inputIdx = 0
+
+    while (resultIdx < requiredColumns.length) {
+      val attr = requiredColumns(resultIdx)
+      if (inputIdx < dataColumns.length && requiredColumns(resultIdx) == dataColumns(inputIdx)) {
+        result.setColumn(resultIdx, input.column(inputIdx))
+        inputIdx += 1
+      } else {
+        require(partitionColumnSchema.fields.filter(_.name.equals(attr.name)).length == 1)
+        var partitionIdx = 0
+        partitionColumnSchema.fields.foreach { f => {
+          if (f.name.equals(attr.name)) {
+            ColumnVectorUtils.populate(result.column(resultIdx), partitionValues, partitionIdx)
+          }
+          partitionIdx += 1
+        }}
+      }
+      resultIdx += 1
+    }
+    result
+  }
+
   private def mergeWithPartitionValues(
       requiredColumns: Seq[Attribute],
       dataColumns: Seq[Attribute],
@@ -239,7 +279,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
         }
       }
 
-      val mapPartitionsFunc = (_: TaskContext, _: Int, iterator: Iterator[InternalRow]) => {
+      val mapPartitionsFunc = (_: TaskContext, _: Int, iterator: Iterator[Object]) => {
         // Note that we can't use an `UnsafeRowJoiner` to replace the following `JoinedRow` and
         // `UnsafeProjection`.  Because the projection may also adjust column order.
         val mutableJoinedRow = new JoinedRow()
@@ -247,9 +287,27 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
         val unsafeProjection =
           UnsafeProjection.create(requiredColumns, dataColumns ++ partitionColumns)
 
-        iterator.map { unsafeDataRow =>
-          unsafeProjection(mutableJoinedRow(unsafeDataRow, unsafePartitionValues))
-        }
+        // If we are returning batches directly, we need to augment them with the partitioning
+        // columns. We want to do this without a row by row operation.
+        var columnBatch: ColumnarBatch = null
+        var mergedBatch: ColumnarBatch = null
+
+        iterator.map { input => {
+          if (input.isInstanceOf[InternalRow]) {
+            unsafeProjection(mutableJoinedRow(
+              input.asInstanceOf[InternalRow], unsafePartitionValues))
+          } else {
+            require(input.isInstanceOf[ColumnarBatch])
+            val inputBatch = input.asInstanceOf[ColumnarBatch]
+            if (inputBatch != mergedBatch) {
+              mergedBatch = inputBatch
+              columnBatch = projectedColumnBatch(inputBatch, requiredColumns,
+                dataColumns, partitionColumnSchema, partitionValues)
+            }
+            columnBatch.setNumRows(inputBatch.numRows())
+            columnBatch
+          }
+        }}
       }
 
       // This is an internal RDD whose call site the user should not be concerned with
@@ -257,7 +315,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
       // the call site may add up.
       Utils.withDummyCallSite(dataRows.sparkContext) {
         new MapPartitionsRDD(dataRows, mapPartitionsFunc, preservesPartitioning = false)
-      }
+      }.asInstanceOf[RDD[InternalRow]]
     } else {
       dataRows
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/a6e2bd31/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala
index f4271d1..c4c7ecc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala
@@ -102,6 +102,8 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag](
     sqlContext.getConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key).toBoolean
   protected val enableVectorizedParquetReader: Boolean =
     sqlContext.getConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key).toBoolean
+  protected val enableWholestageCodegen: Boolean =
+    sqlContext.getConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key).toBoolean
 
   override def getPartitions: Array[SparkPartition] = {
     val conf = getConf(isDriverSide = true)
@@ -179,7 +181,11 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag](
           parquetReader.close()
         } else {
           reader = parquetReader.asInstanceOf[RecordReader[Void, V]]
-          if (enableVectorizedParquetReader) parquetReader.resultBatch()
+          if (enableVectorizedParquetReader) {
+            parquetReader.resultBatch()
+            // Whole stage codegen (PhysicalRDD) is able to deal with batches directly
+            if (enableWholestageCodegen) parquetReader.enableReturningBatches();
+          }
         }
       }
 

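A minimal usage sketch, assuming an existing `sqlContext`. It references the same `SQLConf` entries this code reads; note that `SQLConf` is internal to Spark, so user code would normally set the corresponding string keys instead:

```
import org.apache.spark.sql.internal.SQLConf

// Batches reach PhysicalRDD only when both the vectorized Parquet reader and
// whole stage codegen are enabled; otherwise the reader keeps returning rows.
sqlContext.setConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true")
sqlContext.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true")

// A scan plus aggregate that can consume ColumnarBatches end to end.
sqlContext.read.parquet("/tmp/example.parquet").selectExpr("sum(id)").show()
```
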
http://git-wip-us.apache.org/repos/asf/spark/blob/a6e2bd31/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index c85eedd..cf8a9fd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -37,7 +37,7 @@ import org.apache.parquet.schema.{MessageType, MessageTypeParser}
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection}
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.internal.SQLConf
@@ -683,7 +683,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
           reader.initialize(file, null)
           val result = mutable.ArrayBuffer.empty[(Int, String)]
           while (reader.nextKeyValue()) {
-            val row = reader.getCurrentValue
+            val row = reader.getCurrentValue.asInstanceOf[InternalRow]
             val v = (row.getInt(0), row.getString(1))
             result += v
           }
@@ -700,7 +700,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
           reader.initialize(file, ("_2" :: Nil).asJava)
           val result = mutable.ArrayBuffer.empty[(String)]
           while (reader.nextKeyValue()) {
-            val row = reader.getCurrentValue
+            val row = reader.getCurrentValue.asInstanceOf[InternalRow]
             result += row.getString(0)
           }
           assert(data.map(_._2) == result)
@@ -716,7 +716,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
           reader.initialize(file, ("_2" :: "_1" :: Nil).asJava)
           val result = mutable.ArrayBuffer.empty[(String, Int)]
           while (reader.nextKeyValue()) {
-            val row = reader.getCurrentValue
+            val row = reader.getCurrentValue.asInstanceOf[InternalRow]
             val v = (row.getString(0), row.getInt(1))
             result += v
           }

http://git-wip-us.apache.org/repos/asf/spark/blob/a6e2bd31/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala
index 14dbdf3..38c3618 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala
@@ -22,8 +22,9 @@ import scala.collection.JavaConverters._
 import scala.util.Try
 
 import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.SQLContext
 import org.apache.spark.util.{Benchmark, Utils}
 
 /**
@@ -94,14 +95,14 @@ object ParquetReadBenchmark {
 
         val files = SpecificParquetRecordReaderBase.listDirectory(dir).toArray
         // Driving the parquet reader directly without Spark.
-        parquetReaderBenchmark.addCase("ParquetReader") { num =>
+        parquetReaderBenchmark.addCase("ParquetReader Non-Vectorized") { num =>
           var sum = 0L
           files.map(_.asInstanceOf[String]).foreach { p =>
             val reader = new UnsafeRowParquetRecordReader
             reader.initialize(p, ("id" :: Nil).asJava)
 
             while (reader.nextKeyValue()) {
-              val record = reader.getCurrentValue
+              val record = reader.getCurrentValue.asInstanceOf[InternalRow]
               if (!record.isNullAt(0)) sum += record.getInt(0)
             }
             reader.close()
@@ -109,7 +110,7 @@ object ParquetReadBenchmark {
         }
 
         // Driving the parquet reader in batch mode directly.
-        parquetReaderBenchmark.addCase("ParquetReader(Batched)") { num =>
+        parquetReaderBenchmark.addCase("ParquetReader Vectorized") { num =>
           var sum = 0L
           files.map(_.asInstanceOf[String]).foreach { p =>
             val reader = new UnsafeRowParquetRecordReader
@@ -132,7 +133,7 @@ object ParquetReadBenchmark {
         }
 
         // Decoding in vectorized but having the reader return rows.
-        parquetReaderBenchmark.addCase("ParquetReader(Batch -> Row)") { num =>
+        parquetReaderBenchmark.addCase("ParquetReader Vectorized -> Row") { num =>
           var sum = 0L
           files.map(_.asInstanceOf[String]).foreach { p =>
             val reader = new UnsafeRowParquetRecordReader
@@ -156,9 +157,9 @@ object ParquetReadBenchmark {
         Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz
         SQL Single Int Column Scan:         Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
         -------------------------------------------------------------------------------------------
-        SQL Parquet Vectorized                    657 /  778         23.9          41.8       1.0X
-        SQL Parquet MR                           1606 / 1731          9.8         102.1       0.4X
-        SQL Parquet Non-Vectorized               1133 / 1216         13.9          72.1       0.6X
+        SQL Parquet Vectorized                    215 /  262         73.0          13.7       1.0X
+        SQL Parquet MR                           1946 / 2083          8.1         123.7       0.1X
+        SQL Parquet Non-Vectorized               1079 / 1213         14.6          68.6       0.2X
         */
         sqlBenchmark.run()
 
@@ -166,9 +167,9 @@ object ParquetReadBenchmark {
         Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz
         Parquet Reader Single Int Column    Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
         -------------------------------------------------------------------------------------------
-        ParquetReader                             565 /  609         27.8          35.9       1.0X
-        ParquetReader(Batched)                    165 /  174         95.3          10.5       3.4X
-        ParquetReader(Batch -> Row)               158 /  188         99.3          10.1       3.6X
+        ParquetReader Non-Vectorized              610 /  737         25.8          38.8       1.0X
+        ParquetReader Vectorized                  123 /  152        127.8           7.8       5.0X
+        ParquetReader Vectorized -> Row           165 /  180         95.2          10.5       3.7X
         */
         parquetReaderBenchmark.run()
       }
@@ -209,7 +210,7 @@ object ParquetReadBenchmark {
             val reader = new UnsafeRowParquetRecordReader
             reader.initialize(p, null)
             while (reader.nextKeyValue()) {
-              val record = reader.getCurrentValue
+              val record = reader.getCurrentValue.asInstanceOf[InternalRow]
               if (!record.isNullAt(0)) sum1 += record.getInt(0)
               if (!record.isNullAt(1)) sum2 += record.getUTF8String(1).numBytes()
             }
@@ -221,10 +222,10 @@ object ParquetReadBenchmark {
         Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz
         Int and String Scan:                Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
         -------------------------------------------------------------------------------------------
-        SQL Parquet Vectorized                   1025 / 1180         10.2          97.8       1.0X
-        SQL Parquet MR                           2157 / 2222          4.9         205.7       0.5X
-        SQL Parquet Non-vectorized               1450 / 1466          7.2         138.3       0.7X
-        ParquetReader Non-vectorized             1005 / 1022         10.4          95.9       1.0X
+        SQL Parquet Vectorized                    628 /  720         16.7          59.9       1.0X
+        SQL Parquet MR                           1905 / 2239          5.5         181.7       0.3X
+        SQL Parquet Non-vectorized               1429 / 1732          7.3         136.3       0.4X
+        ParquetReader Non-vectorized              989 / 1357         10.6          94.3       0.6X
         */
         benchmark.run()
       }
@@ -255,17 +256,53 @@ object ParquetReadBenchmark {
         Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz
         String Dictionary:                  Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
         -------------------------------------------------------------------------------------------
-        SQL Parquet Vectorized                    578 /  593         18.1          55.1       1.0X
-        SQL Parquet MR                           1021 / 1032         10.3          97.4       0.6X
+        SQL Parquet Vectorized                    329 /  337         31.9          31.4       1.0X
+        SQL Parquet MR                           1131 / 1325          9.3         107.8       0.3X
         */
         benchmark.run()
       }
     }
   }
 
+  def partitionTableScanBenchmark(values: Int): Unit = {
+    withTempPath { dir =>
+      withTempTable("t1", "tempTable") {
+        sqlContext.range(values).registerTempTable("t1")
+        sqlContext.sql("select id % 2 as p, cast(id as INT) as id from t1")
+          .write.partitionBy("p").parquet(dir.getCanonicalPath)
+        sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable")
+
+        val benchmark = new Benchmark("Partitioned Table", values)
+
+        benchmark.addCase("Read data column") { iter =>
+          sqlContext.sql("select sum(id) from tempTable").collect
+        }
+
+        benchmark.addCase("Read partition column") { iter =>
+          sqlContext.sql("select sum(p) from tempTable").collect
+        }
+
+        benchmark.addCase("Read both columns") { iter =>
+          sqlContext.sql("select sum(p), sum(id) from tempTable").collect
+        }
+
+        /*
+        Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz
+        Partitioned Table:                  Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+        -------------------------------------------------------------------------------------------
+        Read data column                          191 /  250         82.1          12.2       1.0X
+        Read partition column                      82 /   86        192.4           5.2       2.3X
+        Read both columns                         220 /  248         71.5          14.0       0.9X
+         */
+        benchmark.run()
+      }
+    }
+  }
+
   def main(args: Array[String]): Unit = {
     intScanBenchmark(1024 * 1024 * 15)
     intStringScanBenchmark(1024 * 1024 * 10)
     stringDictionaryScanBenchmark(1024 * 1024 * 10)
+    partitionTableScanBenchmark(1024 * 1024 * 15)
   }
 }


