Repository: parquet-cpp
Updated Branches:
refs/heads/master 005767c1c -> de35f8efb
PARQUET-929: Handle arrow::DictionaryArray when writing Arrow data
Author: Korn, Uwe <Uwe.Korn@blue-yonder.com>
Closes #393 from xhochy/PARQUET-929 and squashes the following commits:
55d875c [Korn, Uwe] Add missing unit tests
7fcf62b [Korn, Uwe] PARQUET-929: Handle arrow::DictionaryArray when writing Arrow data
Project: http://git-wip-us.apache.org/repos/asf/parquet-cpp/repo
Commit: http://git-wip-us.apache.org/repos/asf/parquet-cpp/commit/de35f8ef
Tree: http://git-wip-us.apache.org/repos/asf/parquet-cpp/tree/de35f8ef
Diff: http://git-wip-us.apache.org/repos/asf/parquet-cpp/diff/de35f8ef
Branch: refs/heads/master
Commit: de35f8efb4fda5e2ef4c3ed3b707361a902f3bff
Parents: 005767c
Author: Korn, Uwe <Uwe.Korn@blue-yonder.com>
Authored: Sun Sep 17 13:55:09 2017 -0400
Committer: Wes McKinney <wes.mckinney@twosigma.com>
Committed: Sun Sep 17 13:55:09 2017 -0400
----------------------------------------------------------------------
src/parquet/arrow/arrow-reader-writer-test.cc | 49 +++++++++++++++--
src/parquet/arrow/arrow-schema-test.cc | 61 ++++++++++++++++++++++
src/parquet/arrow/schema.cc | 20 +++++--
src/parquet/arrow/writer.cc | 19 +++++++
src/parquet/statistics-test.cc | 5 +-
5 files changed, 144 insertions(+), 10 deletions(-)
----------------------------------------------------------------------
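The user-visible effect: an Arrow table containing a dictionary-encoded column can now be written to Parquet instead of failing with Status::NotImplemented("unhandled type"). A rough sketch of the write path (status checks elided; the Table/Column constructors and the parquet::arrow::WriteTable signature are assumed from the Arrow and parquet-cpp headers of this era, so treat the exact names as illustrative):

    #include "arrow/api.h"
    #include "parquet/arrow/writer.h"

    void WriteDictionaryColumn(const std::shared_ptr<::arrow::io::OutputStream>& sink,
                               const std::shared_ptr<::arrow::Array>& values) {
      // Re-encode a plain array as a DictionaryArray, using the same helper
      // as the new unit test below.
      std::shared_ptr<::arrow::Array> dict_values;
      ::arrow::EncodeArrayToDictionary(*values, ::arrow::default_memory_pool(),
                                       &dict_values);

      auto field = ::arrow::field("col", dict_values->type());
      auto column = std::make_shared<::arrow::Column>(field, dict_values);
      auto table = std::make_shared<::arrow::Table>(
          ::arrow::schema({field}),
          std::vector<std::shared_ptr<::arrow::Column>>{column});

      // Previously this returned NotImplemented; now the dictionary column
      // is decoded internally and written as plain values.
      parquet::arrow::WriteTable(*table, ::arrow::default_memory_pool(), sink,
                                 /*chunk_size=*/1024);
    }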
http://git-wip-us.apache.org/repos/asf/parquet-cpp/blob/de35f8ef/src/parquet/arrow/arrow-reader-writer-test.cc
----------------------------------------------------------------------
diff --git a/src/parquet/arrow/arrow-reader-writer-test.cc b/src/parquet/arrow/arrow-reader-writer-test.cc
index c7ac057..a1e3382 100644
--- a/src/parquet/arrow/arrow-reader-writer-test.cc
+++ b/src/parquet/arrow/arrow-reader-writer-test.cc
@@ -39,18 +39,19 @@
#include "arrow/test-util.h"
using arrow::Array;
+using arrow::ArrayVisitor;
using arrow::Buffer;
-using arrow::Column;
using arrow::ChunkedArray;
-using arrow::default_memory_pool;
-using arrow::io::BufferReader;
+using arrow::Column;
+using arrow::EncodeArrayToDictionary;
using arrow::ListArray;
using arrow::PoolBuffer;
using arrow::PrimitiveArray;
using arrow::Status;
using arrow::Table;
using arrow::TimeUnit;
-using arrow::ArrayVisitor;
+using arrow::default_memory_pool;
+using arrow::io::BufferReader;
using ArrowId = ::arrow::Type;
using ParquetType = parquet::Type;
@@ -109,6 +110,11 @@ LogicalType::type get_logical_type(const ::arrow::DataType& type) {
return LogicalType::TIME_MILLIS;
case ArrowId::TIME64:
return LogicalType::TIME_MICROS;
+ case ArrowId::DICTIONARY: {
+ const ::arrow::DictionaryType& dict_type =
+ static_cast<const ::arrow::DictionaryType&>(type);
+ return get_logical_type(*dict_type.dictionary()->type());
+ }
default:
break;
}
@@ -150,6 +156,11 @@ ParquetType::type get_physical_type(const ::arrow::DataType& type) {
return ParquetType::INT64;
case ArrowId::TIMESTAMP:
return ParquetType::INT64;
+ case ArrowId::DICTIONARY: {
+ const ::arrow::DictionaryType& dict_type =
+ static_cast<const ::arrow::DictionaryType&>(type);
+ return get_physical_type(*dict_type.dictionary()->type());
+ }
default:
break;
}
@@ -331,6 +342,17 @@ static std::shared_ptr<GroupNode> MakeSimpleSchema(const ::arrow::DataType& type
int byte_width;
// Decimal is not implemented yet.
switch (type.id()) {
+ case ::arrow::Type::DICTIONARY: {
+ const ::arrow::DictionaryType& dict_type =
+ static_cast<const ::arrow::DictionaryType&>(type);
+ const ::arrow::DataType& values_type = *dict_type.dictionary()->type();
+ if (values_type.id() == ::arrow::Type::FIXED_SIZE_BINARY) {
+ byte_width =
+ static_cast<const ::arrow::FixedSizeBinaryType&>(values_type).byte_width();
+ } else {
+ byte_width = -1;
+ }
+ } break;
case ::arrow::Type::FIXED_SIZE_BINARY:
byte_width = static_cast<const ::arrow::FixedSizeBinaryType&>(type).byte_width();
break;
@@ -509,6 +531,25 @@ TYPED_TEST(TestParquetIO, SingleColumnOptionalReadWrite) {
this->ReadAndCheckSingleColumnFile(values.get());
}
+TYPED_TEST(TestParquetIO, SingleColumnOptionalDictionaryWrite) {
+ // Skip tests for BOOL as we don't create dictionaries for it.
+ if (TypeParam::type_id == ::arrow::Type::BOOL) {
+ return;
+ }
+
+ std::shared_ptr<Array> values;
+
+ ASSERT_OK(NullableArray<TypeParam>(SMALL_SIZE, 10, kDefaultSeed, &values));
+
+ std::shared_ptr<Array> dict_values;
+ ASSERT_OK(EncodeArrayToDictionary(*values, default_memory_pool(), &dict_values));
+ std::shared_ptr<GroupNode> schema =
+ MakeSimpleSchema(*dict_values->type(), Repetition::OPTIONAL);
+ this->WriteColumn(schema, dict_values);
+
+ this->ReadAndCheckSingleColumnFile(values.get());
+}
+
TYPED_TEST(TestParquetIO, SingleColumnRequiredSliceWrite) {
std::shared_ptr<Array> values;
ASSERT_OK(NonNullArray<TypeParam>(2 * SMALL_SIZE, &values));
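Note the asymmetry in the new SingleColumnOptionalDictionaryWrite test: it writes the dictionary-encoded array but validates the file against the original plain values, since Parquet has no dictionary type at the schema level and reading the file back yields the decoded representation.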
http://git-wip-us.apache.org/repos/asf/parquet-cpp/blob/de35f8ef/src/parquet/arrow/arrow-schema-test.cc
----------------------------------------------------------------------
diff --git a/src/parquet/arrow/arrow-schema-test.cc b/src/parquet/arrow/arrow-schema-test.cc
index 22e3adb..a7a62c5 100644
--- a/src/parquet/arrow/arrow-schema-test.cc
+++ b/src/parquet/arrow/arrow-schema-test.cc
@@ -25,6 +25,7 @@
#include "arrow/api.h"
#include "arrow/test-util.h"
+using arrow::ArrayFromVector;
using arrow::Field;
using arrow::TimeUnit;
@@ -681,6 +682,66 @@ TEST_F(TestConvertArrowSchema, ParquetFlatPrimitives) {
CheckFlatSchema(parquet_fields);
}
+TEST_F(TestConvertArrowSchema, ParquetFlatPrimitivesAsDictionaries) {
+ std::vector<NodePtr> parquet_fields;
+ std::vector<std::shared_ptr<Field>> arrow_fields;
+ std::shared_ptr<::arrow::Array> dict;
+
+ parquet_fields.push_back(
+ PrimitiveNode::Make("int32", Repetition::REQUIRED, ParquetType::INT32));
+ ArrayFromVector<::arrow::Int32Type, int32_t>(std::vector<int32_t>(), &dict);
+ arrow_fields.push_back(
+ ::arrow::field("int32", ::arrow::dictionary(::arrow::int8(), dict), false));
+
+ parquet_fields.push_back(
+ PrimitiveNode::Make("int64", Repetition::REQUIRED, ParquetType::INT64));
+ ArrayFromVector<::arrow::Int64Type, int64_t>(std::vector<int64_t>(), &dict);
+ arrow_fields.push_back(std::make_shared<Field>(
+ "int64", ::arrow::dictionary(::arrow::int8(), dict), false));
+
+ parquet_fields.push_back(PrimitiveNode::Make("date", Repetition::REQUIRED,
+ ParquetType::INT32, LogicalType::DATE));
+ ArrayFromVector<::arrow::Date32Type, int32_t>(std::vector<int32_t>(), &dict);
+ arrow_fields.push_back(
+ std::make_shared<Field>("date", ::arrow::dictionary(::arrow::int8(), dict), false));
+
+ parquet_fields.push_back(PrimitiveNode::Make("date64", Repetition::REQUIRED,
+ ParquetType::INT32, LogicalType::DATE));
+ ArrayFromVector<::arrow::Date64Type, int64_t>(std::vector<int64_t>(), &dict);
+ arrow_fields.push_back(std::make_shared<Field>(
+ "date64", ::arrow::dictionary(::arrow::int8(), dict), false));
+
+ parquet_fields.push_back(
+ PrimitiveNode::Make("float", Repetition::OPTIONAL, ParquetType::FLOAT));
+ ArrayFromVector<::arrow::FloatType, float>(std::vector<float>(), &dict);
+ arrow_fields.push_back(
+ std::make_shared<Field>("float", ::arrow::dictionary(::arrow::int8(), dict)));
+
+ parquet_fields.push_back(
+ PrimitiveNode::Make("double", Repetition::OPTIONAL, ParquetType::DOUBLE));
+ ArrayFromVector<::arrow::DoubleType, double>(std::vector<double>(), &dict);
+ arrow_fields.push_back(
+ std::make_shared<Field>("double", ::arrow::dictionary(::arrow::int8(), dict)));
+
+ parquet_fields.push_back(PrimitiveNode::Make(
+ "string", Repetition::OPTIONAL, ParquetType::BYTE_ARRAY, LogicalType::UTF8));
+ ::arrow::StringBuilder string_builder(::arrow::default_memory_pool());
+ ASSERT_OK(string_builder.Finish(&dict));
+ arrow_fields.push_back(
+ std::make_shared<Field>("string", ::arrow::dictionary(::arrow::int8(), dict)));
+
+ parquet_fields.push_back(PrimitiveNode::Make(
+ "binary", Repetition::OPTIONAL, ParquetType::BYTE_ARRAY, LogicalType::NONE));
+ ::arrow::BinaryBuilder binary_builder(::arrow::default_memory_pool());
+ ASSERT_OK(binary_builder.Finish(&dict));
+ arrow_fields.push_back(
+ std::make_shared<Field>("binary", ::arrow::dictionary(::arrow::int8(), dict)));
+
+ ASSERT_OK(ConvertSchema(arrow_fields));
+
+ CheckFlatSchema(parquet_fields);
+}
+
TEST_F(TestConvertArrowSchema, ParquetLists) {
std::vector<NodePtr> parquet_fields;
std::vector<std::shared_ptr<Field>> arrow_fields;
http://git-wip-us.apache.org/repos/asf/parquet-cpp/blob/de35f8ef/src/parquet/arrow/schema.cc
----------------------------------------------------------------------
diff --git a/src/parquet/arrow/schema.cc b/src/parquet/arrow/schema.cc
index 366f2c3..87e5b38 100644
--- a/src/parquet/arrow/schema.cc
+++ b/src/parquet/arrow/schema.cc
@@ -573,9 +573,23 @@ Status FieldToNode(const std::shared_ptr<Field>& field,
return ListToNode(list_type, field->name(), field->nullable(), properties,
arrow_properties, out);
} break;
- default:
- // TODO: LIST, DENSE_UNION, SPARE_UNION, JSON_SCALAR, DECIMAL, DECIMAL_TEXT, VARCHAR
- return Status::NotImplemented("unhandled type");
+ case ArrowType::DICTIONARY: {
+ // Parquet has no Dictionary type; dictionary encoding is handled at
+ // the encoding level, not the schema level.
+ const ::arrow::DictionaryType& dict_type =
+ static_cast<const ::arrow::DictionaryType&>(*field->type());
+ std::shared_ptr<::arrow::Field> unpacked_field =
+ ::arrow::field(field->name(), dict_type.dictionary()->type(), field->nullable(),
+ field->metadata());
+ return FieldToNode(unpacked_field, properties, arrow_properties, out);
+ }
+ default: {
+ // TODO: DENSE_UNION, SPARSE_UNION, JSON_SCALAR, DECIMAL, DECIMAL_TEXT, VARCHAR
+ std::stringstream ss;
+ ss << "Unhandled type for Arrow to Parquet schema conversion: ";
+ ss << field->type()->ToString();
+ return Status::NotImplemented(ss.str());
+ }
}
*out = PrimitiveNode::Make(field->name(), repetition, type, logical_type, length);
return Status::OK();
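The new DICTIONARY case means schema conversion simply sees through the dictionary layer: for example, a field of type dictionary(int8(), <utf8 dictionary>) maps to the same BYTE_ARRAY/UTF8 node as a plain utf8 field. A minimal sketch, assuming the ToParquetSchema helper declared in parquet/arrow/schema.h (its exact signature here is an assumption):

    // An empty StringArray as the dictionary, mirroring the new schema test.
    std::shared_ptr<::arrow::Array> dict;
    ::arrow::StringBuilder builder(::arrow::default_memory_pool());
    builder.Finish(&dict);  // status check elided

    auto f = ::arrow::field("s", ::arrow::dictionary(::arrow::int8(), dict));
    auto arrow_schema = ::arrow::schema({f});

    std::shared_ptr<parquet::SchemaDescriptor> descriptor;
    // Expected: a single optional BYTE_ARRAY node with UTF8 logical type
    // named "s", exactly as for a plain utf8 field.
    parquet::arrow::ToParquetSchema(arrow_schema.get(),
                                    *parquet::default_writer_properties(),
                                    &descriptor);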
http://git-wip-us.apache.org/repos/asf/parquet-cpp/blob/de35f8ef/src/parquet/arrow/writer.cc
----------------------------------------------------------------------
diff --git a/src/parquet/arrow/writer.cc b/src/parquet/arrow/writer.cc
index 52d0353..55af292 100644
--- a/src/parquet/arrow/writer.cc
+++ b/src/parquet/arrow/writer.cc
@@ -22,6 +22,7 @@
#include <vector>
#include "arrow/api.h"
+#include "arrow/compute/api.h"
#include "arrow/util/bit-util.h"
#include "arrow/visitor_inline.h"
@@ -44,6 +45,10 @@ using arrow::Status;
using arrow::Table;
using arrow::TimeUnit;
+using arrow::compute::Cast;
+using arrow::compute::CastOptions;
+using arrow::compute::FunctionContext;
+
using parquet::ParquetFileWriter;
using parquet::ParquetVersion;
using parquet::schema::GroupNode;
@@ -798,6 +803,20 @@ Status FileWriter::NewRowGroup(int64_t chunk_size) {
}
Status FileWriter::Impl::WriteColumnChunk(const Array& data) {
+ // DictionaryArrays are not yet handled with a fast path. As a workaround
+ // to still support writing them, we convert them back to their
+ // non-dictionary representation.
+ if (data.type()->id() == ::arrow::Type::DICTIONARY) {
+ const ::arrow::DictionaryType& dict_type =
+ static_cast<const ::arrow::DictionaryType&>(*data.type());
+
+ FunctionContext ctx(pool_);
+ std::shared_ptr<Array> plain_array;
+ RETURN_NOT_OK(
+ Cast(&ctx, data, dict_type.dictionary()->type(), CastOptions(), &plain_array));
+ return WriteColumnChunk(*plain_array);
+ }
+
ColumnWriter* column_writer;
PARQUET_CATCH_NOT_OK(column_writer = row_group_writer_->NextColumn());
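The fallback leans entirely on the compute layer: arrow::compute::Cast can decode a dictionary array back to its value type. These are the same calls as in the diff above, isolated into a standalone helper for clarity (the helper name is ours, not part of the commit):

    #include "arrow/compute/api.h"

    // Decode a DictionaryArray to its plain representation, exactly as the
    // writer's workaround path does.
    ::arrow::Status DecodeDictionary(::arrow::MemoryPool* pool,
                                     const ::arrow::Array& data,
                                     std::shared_ptr<::arrow::Array>* out) {
      const auto& dict_type =
          static_cast<const ::arrow::DictionaryType&>(*data.type());
      ::arrow::compute::FunctionContext ctx(pool);
      return ::arrow::compute::Cast(&ctx, data, dict_type.dictionary()->type(),
                                    ::arrow::compute::CastOptions(), out);
    }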
http://git-wip-us.apache.org/repos/asf/parquet-cpp/blob/de35f8ef/src/parquet/statistics-test.cc
----------------------------------------------------------------------
diff --git a/src/parquet/statistics-test.cc b/src/parquet/statistics-test.cc
index 1521cbd..d3ec942 100644
--- a/src/parquet/statistics-test.cc
+++ b/src/parquet/statistics-test.cc
@@ -194,9 +194,8 @@ bool* TestRowGroupStatistics<BooleanType>::GetValuesPointer(std::vector<bool>& v
}
template <typename TestType>
-typename std::vector<typename TestType::c_type>
-TestRowGroupStatistics<TestType>::GetDeepCopy(
- const std::vector<typename TestType::c_type>& values) {
+typename std::vector<typename TestType::c_type> TestRowGroupStatistics<
+ TestType>::GetDeepCopy(const std::vector<typename TestType::c_type>& values) {
return values;
}