spark-commits mailing list archives

From dav...@apache.org
Subject spark git commit: [SPARK-12992] [SQL] Update parquet reader to support more types when decoding to ColumnarBatch.
Date Wed, 03 Feb 2016 00:33:35 GMT
Repository: spark
Updated Branches:
  refs/heads/master 672032d0a -> 21112e8a1


[SPARK-12992] [SQL] Update parquet reader to support more types when decoding to ColumnarBatch.

This patch implements support for more types when doing the vectorized decode. There are
a few more types remaining, but they should be very straightforward to add after this. The
code has a few copy-and-paste pieces that are difficult to eliminate due to performance
considerations.

Specifically, this patch adds support for:
  - String, Long, Byte types
  - Dictionary encoding for those types.
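
The dictionary path decodes in two phases: the encoded ids are first read into a
staging vector, and the values are then resolved against the dictionary in a tight
per-type loop. A minimal self-contained sketch of that idea (plain Java with
illustrative names, not Spark's actual ColumnVector API):

    public class DictionaryDecodeSketch {
      public static void main(String[] args) {
        // Decoded dictionary page: id -> value.
        String[] dictionary = {"abc", "def"};
        // Phase 1: the dictionary ids, decoded into a staging vector
        // (the patch reuses a `dictionaryIds` ColumnVector for this).
        int[] dictionaryIds = {0, 1, 1, 0, 0};
        // Phase 2: remap ids to values in one loop per batch.
        String[] column = new String[dictionaryIds.length];
        for (int i = 0; i < dictionaryIds.length; i++) {
          column[i] = dictionary[dictionaryIds[i]];
        }
        for (String v : column) System.out.println(v); // abc def def abc abc
      }
    }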

Author: Nong Li <nong@databricks.com>

Closes #10908 from nongli/spark-12992.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/21112e8a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/21112e8a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/21112e8a

Branch: refs/heads/master
Commit: 21112e8a14c042ccef4312079672108a1082a95e
Parents: 672032d
Author: Nong Li <nong@databricks.com>
Authored: Tue Feb 2 16:33:21 2016 -0800
Committer: Davies Liu <davies.liu@gmail.com>
Committed: Tue Feb 2 16:33:21 2016 -0800

----------------------------------------------------------------------
 .../parquet/UnsafeRowParquetRecordReader.java   | 146 +++++++++++++++--
 .../parquet/VectorizedPlainValuesReader.java    |  45 +++++-
 .../parquet/VectorizedRleValuesReader.java      | 160 ++++++++++++++++++-
 .../parquet/VectorizedValuesReader.java         |   5 +
 .../sql/execution/vectorized/ColumnVector.java  |   7 +-
 .../parquet/ParquetEncodingSuite.scala          |  82 ++++++++++
 6 files changed, 424 insertions(+), 21 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/21112e8a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
index 17adfec..b5dddb9 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
@@ -21,6 +21,7 @@ import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.List;
 
+import org.apache.commons.lang.NotImplementedException;
 import org.apache.hadoop.mapreduce.InputSplit;
 import org.apache.hadoop.mapreduce.TaskAttemptContext;
 import org.apache.parquet.Preconditions;
@@ -41,6 +42,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder;
 import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter;
 import org.apache.spark.sql.execution.vectorized.ColumnVector;
 import org.apache.spark.sql.execution.vectorized.ColumnarBatch;
+import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.Decimal;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.types.UTF8String;
@@ -207,13 +209,7 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBase
 
     int num = (int)Math.min((long) columnarBatch.capacity(), totalRowCount - rowsReturned);
     for (int i = 0; i < columnReaders.length; ++i) {
-      switch (columnReaders[i].descriptor.getType()) {
-        case INT32:
-          columnReaders[i].readIntBatch(num, columnarBatch.column(i));
-          break;
-        default:
-          throw new IOException("Unsupported type: " + columnReaders[i].descriptor.getType());
-      }
+      columnReaders[i].readBatch(num, columnarBatch.column(i));
     }
     rowsReturned += num;
     columnarBatch.setNumRows(num);
@@ -237,7 +233,8 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBase
 
       // TODO: Be extremely cautious in what is supported. Expand this.
       if (originalTypes[i] != null && originalTypes[i] != OriginalType.DECIMAL &&
-          originalTypes[i] != OriginalType.UTF8 && originalTypes[i] != OriginalType.DATE) {
+          originalTypes[i] != OriginalType.UTF8 && originalTypes[i] != OriginalType.DATE &&
+          originalTypes[i] != OriginalType.INT_8 && originalTypes[i] != OriginalType.INT_16) {
         throw new IOException("Unsupported type: " + t);
       }
       if (originalTypes[i] == OriginalType.DECIMAL &&
@@ -465,6 +462,11 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBase
     private boolean useDictionary;
 
     /**
+     * If useDictionary is true, the staging vector used to decode the ids.
+     */
+    private ColumnVector dictionaryIds;
+
+    /**
      * Maximum definition level for this column.
      */
     private final int maxDefLevel;
@@ -587,9 +589,8 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBase
 
     /**
      * Reads `total` values from this columnReader into column.
-     * TODO: implement the other encodings.
      */
-    private void readIntBatch(int total, ColumnVector column) throws IOException {
+    private void readBatch(int total, ColumnVector column) throws IOException {
       int rowId = 0;
       while (total > 0) {
         // Compute the number of values we want to read in this page.
@@ -599,21 +600,134 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBase
           leftInPage = (int)(endOfPageValueCount - valuesRead);
         }
         int num = Math.min(total, leftInPage);
-        defColumn.readIntegers(
-            num, column, rowId, maxDefLevel, (VectorizedValuesReader)dataColumn, 0);
-
-        // Remap the values if it is dictionary encoded.
         if (useDictionary) {
-          for (int i = rowId; i < rowId + num; ++i) {
-            column.putInt(i, dictionary.decodeToInt(column.getInt(i)));
+          // Data is dictionary encoded. We will vector decode the ids and then resolve the values.
+          if (dictionaryIds == null) {
+            dictionaryIds = ColumnVector.allocate(total, DataTypes.IntegerType, MemoryMode.ON_HEAP);
+          } else {
+            dictionaryIds.reset();
+            dictionaryIds.reserve(total);
+          }
+          // Read and decode dictionary ids.
+          readIntBatch(rowId, num, dictionaryIds);
+          decodeDictionaryIds(rowId, num, column);
+        } else {
+          switch (descriptor.getType()) {
+            case INT32:
+              readIntBatch(rowId, num, column);
+              break;
+            case INT64:
+              readLongBatch(rowId, num, column);
+              break;
+            case BINARY:
+              readBinaryBatch(rowId, num, column);
+              break;
+            default:
+              throw new IOException("Unsupported type: " + descriptor.getType());
           }
         }
+
         valuesRead += num;
         rowId += num;
         total -= num;
       }
     }
 
+    /**
+     * Reads `num` values into column, decoding the values from `dictionaryIds` and `dictionary`.
+     */
+    private void decodeDictionaryIds(int rowId, int num, ColumnVector column) {
+      switch (descriptor.getType()) {
+        case INT32:
+          if (column.dataType() == DataTypes.IntegerType) {
+            for (int i = rowId; i < rowId + num; ++i) {
+              column.putInt(i, dictionary.decodeToInt(dictionaryIds.getInt(i)));
+            }
+          } else if (column.dataType() == DataTypes.ByteType) {
+            for (int i = rowId; i < rowId + num; ++i) {
+              column.putByte(i, (byte)dictionary.decodeToInt(dictionaryIds.getInt(i)));
+            }
+          } else {
+            throw new NotImplementedException("Unimplemented type: " + column.dataType());
+          }
+          break;
+
+        case INT64:
+          for (int i = rowId; i < rowId + num; ++i) {
+            column.putLong(i, dictionary.decodeToLong(dictionaryIds.getInt(i)));
+          }
+          break;
+
+        case BINARY:
+          // TODO: this is incredibly inefficient as it blows up the dictionary right here. We
+          // need to do this better. We should probably add the dictionary data to the ColumnVector
+          // and reuse it across batches. This should mean adding a ByteArray would just update
+          // the length and offset.
+          for (int i = rowId; i < rowId + num; ++i) {
+            Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i));
+            column.putByteArray(i, v.getBytes());
+          }
+          break;
+
+        default:
+          throw new NotImplementedException("Unsupported type: " + descriptor.getType());
+      }
+
+      if (dictionaryIds.numNulls() > 0) {
+        // Copy the NULLs over.
+        // TODO: we can improve this by decoding the NULLs directly into column. This would
+        // mean we decode the int ids into `dictionaryIds` and the NULLs into `column` and then
+        // just do the ID remapping as above.
+        for (int i = 0; i < num; ++i) {
+          if (dictionaryIds.getIsNull(rowId + i)) {
+            column.putNull(rowId + i);
+          }
+        }
+      }
+    }
+
+    /**
+     * For all the read*Batch functions, reads `num` values from this columnReader into column. It
+     * is guaranteed that num is smaller than the number of values left in the current page.
+     */
+
+    private void readIntBatch(int rowId, int num, ColumnVector column) throws IOException {
+      // This is where we implement support for the valid type conversions.
+      // TODO: implement remaining type conversions
+      if (column.dataType() == DataTypes.IntegerType) {
+        defColumn.readIntegers(
+            num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn, 0);
+      } else if (column.dataType() == DataTypes.ByteType) {
+        defColumn.readBytes(
+            num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
+      } else {
+        throw new NotImplementedException("Unimplemented type: " + column.dataType());
+      }
+    }
+
+    private void readLongBatch(int rowId, int num, ColumnVector column) throws IOException {
+      // This is where we implement support for the valid type conversions.
+      // TODO: implement remaining type conversions
+      if (column.dataType() == DataTypes.LongType) {
+        defColumn.readLongs(
+            num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
+      } else {
+        throw new NotImplementedException("Unimplemented type: " + column.dataType());
+      }
+    }
+
+    private void readBinaryBatch(int rowId, int num, ColumnVector column) throws IOException {
+      // This is where we implement support for the valid type conversions.
+      // TODO: implement remaining type conversions
+      if (column.isArray()) {
+        defColumn.readBinarys(
+            num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
+      } else {
+        throw new NotImplementedException("Unimplemented type: " + column.dataType());
+      }
+    }
+
+
     private void readPage() throws IOException {
       DataPage page = pageReader.readPage();
       // TODO: Why is this a visitor?

http://git-wip-us.apache.org/repos/asf/spark/blob/21112e8a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java
b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java
index dac0c52..cec2418 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java
@@ -18,10 +18,13 @@ package org.apache.spark.sql.execution.datasources.parquet;
 
 import java.io.IOException;
 
+import org.apache.spark.sql.Column;
 import org.apache.spark.sql.execution.vectorized.ColumnVector;
 import org.apache.spark.unsafe.Platform;
 
+import org.apache.commons.lang.NotImplementedException;
 import org.apache.parquet.column.values.ValuesReader;
+import org.apache.parquet.io.api.Binary;
 
 /**
  * An implementation of the Parquet PLAIN decoder that supports the vectorized interface.
@@ -52,15 +55,53 @@ public class VectorizedPlainValuesReader extends ValuesReader implements VectorizedValuesReader
   }
 
   @Override
-  public void readIntegers(int total, ColumnVector c, int rowId) {
+  public final void readIntegers(int total, ColumnVector c, int rowId) {
     c.putIntsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET);
     offset += 4 * total;
   }
 
   @Override
-  public int readInteger() {
+  public final void readLongs(int total, ColumnVector c, int rowId) {
+    c.putLongsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET);
+    offset += 8 * total;
+  }
+
+  @Override
+  public final void readBytes(int total, ColumnVector c, int rowId) {
+    for (int i = 0; i < total; i++) {
+      // Bytes are stored as a 4-byte little endian int. Just read the first byte.
+      // TODO: consider pushing this in ColumnVector by adding a readBytes with a stride.
+      c.putInt(rowId + i, buffer[offset]);
+      offset += 4;
+    }
+  }
+
+  @Override
+  public final int readInteger() {
     int v = Platform.getInt(buffer, offset);
     offset += 4;
     return v;
   }
+
+  @Override
+  public final long readLong() {
+    long v = Platform.getLong(buffer, offset);
+    offset += 8;
+    return v;
+  }
+
+  @Override
+  public final byte readByte() {
+    return (byte)readInteger();
+  }
+
+  @Override
+  public final void readBinary(int total, ColumnVector v, int rowId) {
+    for (int i = 0; i < total; i++) {
+      int len = readInteger();
+      int start = offset;
+      offset += len;
+      v.putByteArray(rowId + i, buffer, start - Platform.BYTE_ARRAY_OFFSET, len);
+    }
+  }
 }
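
The readBytes/readByte paths above rely on byte values being plain-encoded as
4-byte little-endian ints, which is why reading the single byte at `offset`
recovers the value. A small self-contained check of that layout (plain Java,
no Parquet dependencies):

    import java.nio.ByteBuffer;
    import java.nio.ByteOrder;

    public class LittleEndianByteCheck {
      public static void main(String[] args) {
        ByteBuffer buf = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN);
        buf.putInt(0x12); // a byte value, stored as a 4-byte int
        // Little endian puts the least significant byte first, so a
        // single-byte read at the start of the int yields the value.
        System.out.println(buf.array()[0] == 0x12); // prints true
      }
    }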

http://git-wip-us.apache.org/repos/asf/spark/blob/21112e8a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
index 493ec9d..9bfd74d 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
@@ -17,12 +17,16 @@
 
 package org.apache.spark.sql.execution.datasources.parquet;
 
+import org.apache.commons.lang.NotImplementedException;
 import org.apache.parquet.Preconditions;
 import org.apache.parquet.bytes.BytesUtils;
 import org.apache.parquet.column.values.ValuesReader;
 import org.apache.parquet.column.values.bitpacking.BytePacker;
 import org.apache.parquet.column.values.bitpacking.Packer;
 import org.apache.parquet.io.ParquetDecodingException;
+import org.apache.parquet.io.api.Binary;
+
+import org.apache.spark.sql.Column;
 import org.apache.spark.sql.execution.vectorized.ColumnVector;
 
 /**
@@ -35,7 +39,8 @@ import org.apache.spark.sql.execution.vectorized.ColumnVector;
  *  - Definition/Repetition levels
  *  - Dictionary ids.
  */
-public final class VectorizedRleValuesReader extends ValuesReader {
+public final class VectorizedRleValuesReader extends ValuesReader
+    implements VectorizedValuesReader {
  // Current decoding mode. The encoded data contains groups of either run length encoded data
  // (RLE) or bit packed data. Each group contains a header that indicates which group it is and
   // the number of values in the group.
@@ -121,6 +126,7 @@ public final class VectorizedRleValuesReader extends ValuesReader {
     return readInteger();
   }
 
+
   @Override
   public int readInteger() {
     if (this.currentCount == 0) { this.readNextGroup(); }
@@ -138,7 +144,9 @@ public final class VectorizedRleValuesReader extends ValuesReader {
   /**
    * Reads `total` ints into `c` filling them in starting at `c[rowId]`. This reader
    * reads the definition levels and then will read from `data` for the non-null values.
-   * If the value is null, c will be populated with `nullValue`.
+   * If the value is null, c will be populated with `nullValue`. Note that `nullValue` is only
+   * necessary for readIntegers because we also use it to decode dictionaryIds and want to make
+   * sure it always has a value in range.
    *
    * This is a batched version of this logic:
    *  if (this.readInt() == level) {
@@ -180,6 +188,154 @@ public final class VectorizedRleValuesReader extends ValuesReader {
     }
   }
 
+  // TODO: can this code duplication be removed without a perf penalty?
+  public void readBytes(int total, ColumnVector c,
+                        int rowId, int level, VectorizedValuesReader data) {
+    int left = total;
+    while (left > 0) {
+      if (this.currentCount == 0) this.readNextGroup();
+      int n = Math.min(left, this.currentCount);
+      switch (mode) {
+        case RLE:
+          if (currentValue == level) {
+            data.readBytes(n, c, rowId);
+            c.putNotNulls(rowId, n);
+          } else {
+            c.putNulls(rowId, n);
+          }
+          break;
+        case PACKED:
+          for (int i = 0; i < n; ++i) {
+            if (currentBuffer[currentBufferIdx++] == level) {
+              c.putByte(rowId + i, data.readByte());
+              c.putNotNull(rowId + i);
+            } else {
+              c.putNull(rowId + i);
+            }
+          }
+          break;
+      }
+      rowId += n;
+      left -= n;
+      currentCount -= n;
+    }
+  }
+
+  public void readLongs(int total, ColumnVector c, int rowId, int level,
+                        VectorizedValuesReader data) {
+    int left = total;
+    while (left > 0) {
+      if (this.currentCount == 0) this.readNextGroup();
+      int n = Math.min(left, this.currentCount);
+      switch (mode) {
+        case RLE:
+          if (currentValue == level) {
+            data.readLongs(n, c, rowId);
+            c.putNotNulls(rowId, n);
+          } else {
+            c.putNulls(rowId, n);
+          }
+          break;
+        case PACKED:
+          for (int i = 0; i < n; ++i) {
+            if (currentBuffer[currentBufferIdx++] == level) {
+              c.putLong(rowId + i, data.readLong());
+              c.putNotNull(rowId + i);
+            } else {
+              c.putNull(rowId + i);
+            }
+          }
+          break;
+      }
+      rowId += n;
+      left -= n;
+      currentCount -= n;
+    }
+  }
+
+  public void readBinarys(int total, ColumnVector c, int rowId, int level,
+                        VectorizedValuesReader data) {
+    int left = total;
+    while (left > 0) {
+      if (this.currentCount == 0) this.readNextGroup();
+      int n = Math.min(left, this.currentCount);
+      switch (mode) {
+        case RLE:
+          if (currentValue == level) {
+            c.putNotNulls(rowId, n);
+            data.readBinary(n, c, rowId);
+          } else {
+            c.putNulls(rowId, n);
+          }
+          break;
+        case PACKED:
+          for (int i = 0; i < n; ++i) {
+            if (currentBuffer[currentBufferIdx++] == level) {
+              c.putNotNull(rowId + i);
+              data.readBinary(1, c, rowId);
+            } else {
+              c.putNull(rowId + i);
+            }
+          }
+          break;
+      }
+      rowId += n;
+      left -= n;
+      currentCount -= n;
+    }
+  }
+
+
+  // The RLE reader implements the vectorized decoding interface when used to decode dictionary
+  // IDs. This is different than the above APIs that decodes definitions levels along with values.
+  // Since this is only used to decode dictionary IDs, only decoding integers is supported.
+  @Override
+  public void readIntegers(int total, ColumnVector c, int rowId) {
+    int left = total;
+    while (left > 0) {
+      if (this.currentCount == 0) this.readNextGroup();
+      int n = Math.min(left, this.currentCount);
+      switch (mode) {
+        case RLE:
+          c.putInts(rowId, n, currentValue);
+          break;
+        case PACKED:
+          c.putInts(rowId, n, currentBuffer, currentBufferIdx);
+          currentBufferIdx += n;
+          break;
+      }
+      rowId += n;
+      left -= n;
+      currentCount -= n;
+    }
+  }
+
+  @Override
+  public byte readByte() {
+    throw new UnsupportedOperationException("only readInts is valid.");
+  }
+
+  @Override
+  public void readBytes(int total, ColumnVector c, int rowId) {
+    throw new UnsupportedOperationException("only readInts is valid.");
+  }
+
+  @Override
+  public void readLongs(int total, ColumnVector c, int rowId) {
+    throw new UnsupportedOperationException("only readInts is valid.");
+  }
+
+  @Override
+  public void readBinary(int total, ColumnVector c, int rowId) {
+    throw new UnsupportedOperationException("only readInts is valid.");
+  }
+
+  @Override
+  public void skip(int n) {
+    throw new UnsupportedOperationException("only readInts is valid.");
+  }
+
+
   /**
    * Reads the next varint encoded int.
    */
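
The readIntegers implementation above consumes Parquet's RLE/bit-packed hybrid
layout: each group is either an RLE run (one value plus a count) or a bit-packed
run of literal values. A toy model of that outer loop (illustrative names only,
not the real reader, which also tracks group headers and bit widths):

    public class HybridDecodeSketch {
      public static void main(String[] args) {
        int[] c = new int[5];
        int rowId = 0;
        // An RLE group contributes a single value repeated `currentCount` times.
        int currentValue = 7, currentCount = 3;
        for (int i = 0; i < currentCount; i++) c[rowId + i] = currentValue;
        rowId += currentCount;
        // A PACKED group contributes already-unpacked literal values.
        int[] currentBuffer = {1, 2};
        System.arraycopy(currentBuffer, 0, c, rowId, currentBuffer.length);
        for (int v : c) System.out.print(v + " "); // 7 7 7 1 2
      }
    }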

http://git-wip-us.apache.org/repos/asf/spark/blob/21112e8a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java
b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java
index 49a9ed8..b6ec731 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java
@@ -24,12 +24,17 @@ import org.apache.spark.sql.execution.vectorized.ColumnVector;
  * TODO: merge this into parquet-mr.
  */
 public interface VectorizedValuesReader {
+  byte readByte();
   int readInteger();
+  long readLong();
 
   /*
    * Reads `total` values into `c` start at `c[rowId]`
    */
+  void readBytes(int total, ColumnVector c, int rowId);
   void readIntegers(int total, ColumnVector c, int rowId);
+  void readLongs(int total, ColumnVector c, int rowId);
+  void readBinary(int total, ColumnVector c, int rowId);
 
   // TODO: add all the other parquet types.
 

http://git-wip-us.apache.org/repos/asf/spark/blob/21112e8a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
index a5bc506..0514252 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
@@ -763,7 +763,12 @@ public abstract class ColumnVector {
   /**
    * Returns the elements appended.
    */
-  public int getElementsAppended() { return elementsAppended; }
+  public final int getElementsAppended() { return elementsAppended; }
+
+  /**
+   * Returns true if this column is an array.
+   */
+  public final boolean isArray() { return resultArray != null; }
 
   /**
    * Maximum number of rows that can be stored in this column.

http://git-wip-us.apache.org/repos/asf/spark/blob/21112e8a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala
new file mode 100644
index 0000000..cef6b79
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.parquet
+
+import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils
+import org.apache.spark.sql.test.SharedSQLContext
+
+// TODO: this needs a lot more testing but it's currently not easy to test with the parquet
+// writer abstractions. Revisit.
+class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContext {
+  import testImplicits._
+
+  val ROW = ((1).toByte, 2, 3L, "abc")
+  val NULL_ROW = (
+    null.asInstanceOf[java.lang.Byte],
+    null.asInstanceOf[Integer],
+    null.asInstanceOf[java.lang.Long],
+    null.asInstanceOf[String])
+
+  test("All Types Dictionary") {
+    (1 :: 1000 :: Nil).foreach { n => {
+      withTempPath { dir =>
+        List.fill(n)(ROW).toDF.repartition(1).write.parquet(dir.getCanonicalPath)
+        val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head
+
+        val reader = new UnsafeRowParquetRecordReader
+        reader.initialize(file.asInstanceOf[String], null)
+        val batch = reader.resultBatch()
+        assert(reader.nextBatch())
+        assert(batch.numRows() == n)
+        var i = 0
+        while (i < n) {
+          assert(batch.column(0).getByte(i) == 1)
+          assert(batch.column(1).getInt(i) == 2)
+          assert(batch.column(2).getLong(i) == 3)
+          assert(ColumnVectorUtils.toString(batch.column(3).getByteArray(i)) == "abc")
+          i += 1
+        }
+        reader.close()
+      }
+    }}
+  }
+
+  test("All Types Null") {
+    (1 :: 100 :: Nil).foreach { n => {
+      withTempPath { dir =>
+        val data = List.fill(n)(NULL_ROW).toDF
+        data.repartition(1).write.parquet(dir.getCanonicalPath)
+        val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head
+
+        val reader = new UnsafeRowParquetRecordReader
+        reader.initialize(file.asInstanceOf[String], null)
+        val batch = reader.resultBatch()
+        assert(reader.nextBatch())
+        assert(batch.numRows() == n)
+        var i = 0
+        while (i < n) {
+          assert(batch.column(0).getIsNull(i))
+          assert(batch.column(1).getIsNull(i))
+          assert(batch.column(2).getIsNull(i))
+          assert(batch.column(3).getIsNull(i))
+          i += 1
+        }
+        reader.close()
+      }}
+    }
+  }
+}



