Return-Path: X-Original-To: archive-asf-public-internal@cust-asf2.ponee.io Delivered-To: archive-asf-public-internal@cust-asf2.ponee.io Received: from cust-asf.ponee.io (cust-asf.ponee.io [163.172.22.183]) by cust-asf2.ponee.io (Postfix) with ESMTP id 79245200C28 for ; Mon, 27 Feb 2017 00:25:11 +0100 (CET) Received: by cust-asf.ponee.io (Postfix) id 7477B160B77; Sun, 26 Feb 2017 23:25:11 +0000 (UTC) Delivered-To: archive-asf-public@cust-asf.ponee.io Received: from mail.apache.org (hermes.apache.org [140.211.11.3]) by cust-asf.ponee.io (Postfix) with SMTP id F3F05160B6E for ; Mon, 27 Feb 2017 00:25:09 +0100 (CET) Received: (qmail 25409 invoked by uid 500); 26 Feb 2017 23:25:09 -0000 Mailing-List: contact commits-help@arrow.apache.org; run by ezmlm Precedence: bulk List-Help: List-Unsubscribe: List-Post: List-Id: Reply-To: dev@arrow.apache.org Delivered-To: mailing list commits@arrow.apache.org Received: (qmail 25399 invoked by uid 99); 26 Feb 2017 23:25:09 -0000 Received: from git1-us-west.apache.org (HELO git1-us-west.apache.org) (140.211.11.23) by apache.org (qpsmtpd/0.29) with ESMTP; Sun, 26 Feb 2017 23:25:09 +0000 Received: by git1-us-west.apache.org (ASF Mail Server at git1-us-west.apache.org, from userid 33) id 0EEFCDFDA9; Sun, 26 Feb 2017 23:25:09 +0000 (UTC) Content-Type: text/plain; charset="us-ascii" MIME-Version: 1.0 Content-Transfer-Encoding: 7bit From: wesm@apache.org To: commits@arrow.apache.org Message-Id: <35ad957c5fdc40abb97f972cabaef51a@git.apache.org> X-Mailer: ASF-Git Admin Mailer Subject: arrow git commit: ARROW-451: [C++] Implement DataType::Equals as TypeVisitor. Add default implementations for TypeVisitor, ArrayVisitor methods Date: Sun, 26 Feb 2017 23:25:09 +0000 (UTC) archived-at: Sun, 26 Feb 2017 23:25:11 -0000 Repository: arrow Updated Branches: refs/heads/master 8afe92c6c -> ef3b6b344 ARROW-451: [C++] Implement DataType::Equals as TypeVisitor. Add default implementations for TypeVisitor, ArrayVisitor methods This patch also resolves ARROW-568. Added tests for TimeType, TimestampType, which were not having their `unit` metadata compared due to an oversight. Author: Wes McKinney Closes #350 from wesm/ARROW-451 and squashes the following commits: 97e75d8 [Wes McKinney] Export ArrayVisitor, TypeVisitor symbols a3332be [Wes McKinney] Typo 635e74d [Wes McKinney] Implement DataType::Equals as TypeVisitor, compare child metadata. Add default implementations for TypeVisitor, ArrayVisitor methods Project: http://git-wip-us.apache.org/repos/asf/arrow/repo Commit: http://git-wip-us.apache.org/repos/asf/arrow/commit/ef3b6b34 Tree: http://git-wip-us.apache.org/repos/asf/arrow/tree/ef3b6b34 Diff: http://git-wip-us.apache.org/repos/asf/arrow/diff/ef3b6b34 Branch: refs/heads/master Commit: ef3b6b34482c36615af5064f474363126e755a18 Parents: 8afe92c Author: Wes McKinney Authored: Sun Feb 26 18:25:03 2017 -0500 Committer: Wes McKinney Committed: Sun Feb 26 18:25:03 2017 -0500 ---------------------------------------------------------------------- cpp/src/arrow/CMakeLists.txt | 2 +- cpp/src/arrow/array.cc | 36 ++++++++ cpp/src/arrow/array.h | 50 +++++------ cpp/src/arrow/compare.cc | 108 +++++++++++++++++++++-- cpp/src/arrow/compare.h | 5 ++ cpp/src/arrow/ipc/adapter.cc | 20 ----- cpp/src/arrow/ipc/json-internal.cc | 30 ------- cpp/src/arrow/schema-test.cc | 122 -------------------------- cpp/src/arrow/type-test.cc | 146 ++++++++++++++++++++++++++++++++ cpp/src/arrow/type.cc | 69 +++++++++++---- cpp/src/arrow/type.h | 64 ++++++-------- 11 files changed, 394 insertions(+), 258 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/arrow/blob/ef3b6b34/cpp/src/arrow/CMakeLists.txt ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 824ced1..d1efa02 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -58,8 +58,8 @@ ADD_ARROW_TEST(buffer-test) ADD_ARROW_TEST(column-test) ADD_ARROW_TEST(memory_pool-test) ADD_ARROW_TEST(pretty_print-test) -ADD_ARROW_TEST(schema-test) ADD_ARROW_TEST(status-test) +ADD_ARROW_TEST(type-test) ADD_ARROW_TEST(table-test) ADD_ARROW_BENCHMARK(builder-benchmark) http://git-wip-us.apache.org/repos/asf/arrow/blob/ef3b6b34/cpp/src/arrow/array.cc ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/array.cc b/cpp/src/arrow/array.cc index 81678e3..eb4c210 100644 --- a/cpp/src/arrow/array.cc +++ b/cpp/src/arrow/array.cc @@ -503,4 +503,40 @@ Status MakePrimitiveArray(const std::shared_ptr& type, int32_t length, #endif } +// ---------------------------------------------------------------------- +// Default implementations of ArrayVisitor methods + +#define ARRAY_VISITOR_DEFAULT(ARRAY_CLASS) \ + Status ArrayVisitor::Visit(const ARRAY_CLASS& array) { \ + return Status::NotImplemented(array.type()->ToString()); \ + } + +ARRAY_VISITOR_DEFAULT(NullArray); +ARRAY_VISITOR_DEFAULT(BooleanArray); +ARRAY_VISITOR_DEFAULT(Int8Array); +ARRAY_VISITOR_DEFAULT(Int16Array); +ARRAY_VISITOR_DEFAULT(Int32Array); +ARRAY_VISITOR_DEFAULT(Int64Array); +ARRAY_VISITOR_DEFAULT(UInt8Array); +ARRAY_VISITOR_DEFAULT(UInt16Array); +ARRAY_VISITOR_DEFAULT(UInt32Array); +ARRAY_VISITOR_DEFAULT(UInt64Array); +ARRAY_VISITOR_DEFAULT(HalfFloatArray); +ARRAY_VISITOR_DEFAULT(FloatArray); +ARRAY_VISITOR_DEFAULT(DoubleArray); +ARRAY_VISITOR_DEFAULT(StringArray); +ARRAY_VISITOR_DEFAULT(BinaryArray); +ARRAY_VISITOR_DEFAULT(DateArray); +ARRAY_VISITOR_DEFAULT(TimeArray); +ARRAY_VISITOR_DEFAULT(TimestampArray); +ARRAY_VISITOR_DEFAULT(IntervalArray); +ARRAY_VISITOR_DEFAULT(ListArray); +ARRAY_VISITOR_DEFAULT(StructArray); +ARRAY_VISITOR_DEFAULT(UnionArray); +ARRAY_VISITOR_DEFAULT(DictionaryArray); + +Status ArrayVisitor::Visit(const DecimalArray& array) { + return Status::NotImplemented("decimal"); +} + } // namespace arrow http://git-wip-us.apache.org/repos/asf/arrow/blob/ef3b6b34/cpp/src/arrow/array.h ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/array.h b/cpp/src/arrow/array.h index 9bb06af..8bb914e 100644 --- a/cpp/src/arrow/array.h +++ b/cpp/src/arrow/array.h @@ -38,34 +38,34 @@ class MemoryPool; class MutableBuffer; class Status; -class ArrayVisitor { +class ARROW_EXPORT ArrayVisitor { public: virtual ~ArrayVisitor() = default; - virtual Status Visit(const NullArray& array) = 0; - virtual Status Visit(const BooleanArray& array) = 0; - virtual Status Visit(const Int8Array& array) = 0; - virtual Status Visit(const Int16Array& array) = 0; - virtual Status Visit(const Int32Array& array) = 0; - virtual Status Visit(const Int64Array& array) = 0; - virtual Status Visit(const UInt8Array& array) = 0; - virtual Status Visit(const UInt16Array& array) = 0; - virtual Status Visit(const UInt32Array& array) = 0; - virtual Status Visit(const UInt64Array& array) = 0; - virtual Status Visit(const HalfFloatArray& array) = 0; - virtual Status Visit(const FloatArray& array) = 0; - virtual Status Visit(const DoubleArray& array) = 0; - virtual Status Visit(const StringArray& array) = 0; - virtual Status Visit(const BinaryArray& array) = 0; - virtual Status Visit(const DateArray& array) = 0; - virtual Status Visit(const TimeArray& array) = 0; - virtual Status Visit(const TimestampArray& array) = 0; - virtual Status Visit(const IntervalArray& array) = 0; - virtual Status Visit(const DecimalArray& array) = 0; - virtual Status Visit(const ListArray& array) = 0; - virtual Status Visit(const StructArray& array) = 0; - virtual Status Visit(const UnionArray& array) = 0; - virtual Status Visit(const DictionaryArray& type) = 0; + virtual Status Visit(const NullArray& array); + virtual Status Visit(const BooleanArray& array); + virtual Status Visit(const Int8Array& array); + virtual Status Visit(const Int16Array& array); + virtual Status Visit(const Int32Array& array); + virtual Status Visit(const Int64Array& array); + virtual Status Visit(const UInt8Array& array); + virtual Status Visit(const UInt16Array& array); + virtual Status Visit(const UInt32Array& array); + virtual Status Visit(const UInt64Array& array); + virtual Status Visit(const HalfFloatArray& array); + virtual Status Visit(const FloatArray& array); + virtual Status Visit(const DoubleArray& array); + virtual Status Visit(const StringArray& array); + virtual Status Visit(const BinaryArray& array); + virtual Status Visit(const DateArray& array); + virtual Status Visit(const TimeArray& array); + virtual Status Visit(const TimestampArray& array); + virtual Status Visit(const IntervalArray& array); + virtual Status Visit(const DecimalArray& array); + virtual Status Visit(const ListArray& array); + virtual Status Visit(const StructArray& array); + virtual Status Visit(const UnionArray& array); + virtual Status Visit(const DictionaryArray& type); }; /// Immutable data array with some logical type and some length. http://git-wip-us.apache.org/repos/asf/arrow/blob/ef3b6b34/cpp/src/arrow/compare.cc ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/compare.cc b/cpp/src/arrow/compare.cc index 21fdb66..ff3c59f 100644 --- a/cpp/src/arrow/compare.cc +++ b/cpp/src/arrow/compare.cc @@ -301,9 +301,9 @@ class RangeEqualsVisitor : public ArrayVisitor { bool result_; }; -class EqualsVisitor : public RangeEqualsVisitor { +class ArrayEqualsVisitor : public RangeEqualsVisitor { public: - explicit EqualsVisitor(const Array& right) + explicit ArrayEqualsVisitor(const Array& right) : RangeEqualsVisitor(right, 0, right.length(), 0) {} Status Visit(const NullArray& left) override { return Status::OK(); } @@ -511,9 +511,9 @@ inline bool FloatingApproxEquals( return true; } -class ApproxEqualsVisitor : public EqualsVisitor { +class ApproxEqualsVisitor : public ArrayEqualsVisitor { public: - using EqualsVisitor::EqualsVisitor; + using ArrayEqualsVisitor::ArrayEqualsVisitor; Status Visit(const FloatArray& left) override { result_ = @@ -549,7 +549,7 @@ Status ArrayEquals(const Array& left, const Array& right, bool* are_equal) { } else if (left.length() == 0) { *are_equal = true; } else { - EqualsVisitor visitor(right); + ArrayEqualsVisitor visitor(right); RETURN_NOT_OK(left.Accept(&visitor)); *are_equal = visitor.result(); } @@ -588,4 +588,102 @@ Status ArrayApproxEquals(const Array& left, const Array& right, bool* are_equal) return Status::OK(); } +// ---------------------------------------------------------------------- +// Implement TypeEquals + +class TypeEqualsVisitor : public TypeVisitor { + public: + explicit TypeEqualsVisitor(const DataType& right) : right_(right), result_(false) {} + + Status VisitChildren(const DataType& left) { + if (left.num_children() != right_.num_children()) { + result_ = false; + return Status::OK(); + } + + for (int i = 0; i < left.num_children(); ++i) { + if (!left.child(i)->Equals(right_.child(i))) { + result_ = false; + break; + } + } + result_ = true; + return Status::OK(); + } + + Status Visit(const TimeType& left) override { + const auto& right = static_cast(right_); + result_ = left.unit == right.unit; + return Status::OK(); + } + + Status Visit(const TimestampType& left) override { + const auto& right = static_cast(right_); + result_ = left.unit == right.unit; + return Status::OK(); + } + + Status Visit(const ListType& left) override { return VisitChildren(left); } + + Status Visit(const StructType& left) override { return VisitChildren(left); } + + Status Visit(const UnionType& left) override { + const auto& right = static_cast(right_); + + if (left.mode != right.mode || left.type_codes.size() != right.type_codes.size()) { + result_ = false; + return Status::OK(); + } + + const std::vector left_codes = left.type_codes; + const std::vector right_codes = right.type_codes; + + for (size_t i = 0; i < left_codes.size(); ++i) { + if (left_codes[i] != right_codes[i]) { + result_ = false; + break; + } + } + result_ = true; + return Status::OK(); + } + + Status Visit(const DictionaryType& left) override { + const auto& right = static_cast(right_); + result_ = left.index_type()->Equals(right.index_type()) && + left.dictionary()->Equals(right.dictionary()); + return Status::OK(); + } + + bool result() const { return result_; } + + protected: + const DataType& right_; + bool result_; +}; + +Status TypeEquals(const DataType& left, const DataType& right, bool* are_equal) { + // The arrays are the same object + if (&left == &right) { + *are_equal = true; + } else if (left.type != right.type) { + *are_equal = false; + } else { + TypeEqualsVisitor visitor(right); + Status s = left.Accept(&visitor); + + // We do not implement any type visitors where there is no additional + // metadata to compare. + if (s.IsNotImplemented()) { + // Not implemented means there is no additional metadata to compare + *are_equal = true; + } else if (!s.ok()) { + return s; + } else { + *are_equal = visitor.result(); + } + } + return Status::OK(); +} + } // namespace arrow http://git-wip-us.apache.org/repos/asf/arrow/blob/ef3b6b34/cpp/src/arrow/compare.h ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/compare.h b/cpp/src/arrow/compare.h index 2093b65..6a71f9f 100644 --- a/cpp/src/arrow/compare.h +++ b/cpp/src/arrow/compare.h @@ -27,6 +27,7 @@ namespace arrow { class Array; +struct DataType; class Status; /// Returns true if the arrays are exactly equal @@ -41,6 +42,10 @@ Status ARROW_EXPORT ArrayApproxEquals( Status ARROW_EXPORT ArrayRangeEquals(const Array& left, const Array& right, int32_t start_idx, int32_t end_idx, int32_t other_start_idx, bool* are_equal); +/// Returns true if the type metadata are exactly equal +Status ARROW_EXPORT TypeEquals( + const DataType& left, const DataType& right, bool* are_equal); + } // namespace arrow #endif // ARROW_COMPARE_H http://git-wip-us.apache.org/repos/asf/arrow/blob/ef3b6b34/cpp/src/arrow/ipc/adapter.cc ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/ipc/adapter.cc b/cpp/src/arrow/ipc/adapter.cc index 08ac983..2be87a3 100644 --- a/cpp/src/arrow/ipc/adapter.cc +++ b/cpp/src/arrow/ipc/adapter.cc @@ -227,8 +227,6 @@ class RecordBatchWriter : public ArrayVisitor { } protected: - Status Visit(const NullArray& array) override { return Status::NotImplemented("null"); } - template Status VisitFixedWidth(const ArrayType& array) { std::shared_ptr data_buffer = array.data(); @@ -360,14 +358,6 @@ class RecordBatchWriter : public ArrayVisitor { return VisitFixedWidth(array); } - Status Visit(const IntervalArray& array) override { - return Status::NotImplemented("interval"); - } - - Status Visit(const DecimalArray& array) override { - return Status::NotImplemented("decimal"); - } - Status Visit(const ListArray& array) override { std::shared_ptr value_offsets; RETURN_NOT_OK(GetZeroBasedValueOffsets(array, &value_offsets)); @@ -653,8 +643,6 @@ class ArrayLoader : public TypeVisitor { return Status::OK(); } - Status Visit(const NullType& type) override { return Status::NotImplemented("null"); } - Status Visit(const BooleanType& type) override { return LoadPrimitive(type); } Status Visit(const Int8Type& type) override { return LoadPrimitive(type); } @@ -689,14 +677,6 @@ class ArrayLoader : public TypeVisitor { Status Visit(const TimestampType& type) override { return LoadPrimitive(type); } - Status Visit(const IntervalType& type) override { - return Status::NotImplemented(type.ToString()); - } - - Status Visit(const DecimalType& type) override { - return Status::NotImplemented(type.ToString()); - } - Status Visit(const ListType& type) override { FieldMetadata field_meta; std::shared_ptr null_bitmap, offsets; http://git-wip-us.apache.org/repos/asf/arrow/blob/ef3b6b34/cpp/src/arrow/ipc/json-internal.cc ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/ipc/json-internal.cc b/cpp/src/arrow/ipc/json-internal.cc index b9f97dd..6253cd6 100644 --- a/cpp/src/arrow/ipc/json-internal.cc +++ b/cpp/src/arrow/ipc/json-internal.cc @@ -316,8 +316,6 @@ class JsonSchemaWriter : public TypeVisitor { return WritePrimitive("interval", type); } - Status Visit(const DecimalType& type) override { return Status::NotImplemented("NYI"); } - Status Visit(const ListType& type) override { WriteName("list", type); RETURN_NOT_OK(WriteChildren(type.children())); @@ -339,14 +337,6 @@ class JsonSchemaWriter : public TypeVisitor { return Status::OK(); } - Status Visit(const DictionaryType& type) override { - // WriteName("dictionary", type); - // WriteChildren(type.children()); - // WriteBufferLayout(type.GetBufferLayout()); - // return Status::OK(); - return Status::NotImplemented("dictionary type"); - } - private: const Schema& schema_; RjWriter* writer_; @@ -531,22 +521,6 @@ class JsonArrayWriter : public ArrayVisitor { Status Visit(const BinaryArray& array) override { return WriteVarBytes(array); } - Status Visit(const DateArray& array) override { return Status::NotImplemented("date"); } - - Status Visit(const TimeArray& array) override { return Status::NotImplemented("time"); } - - Status Visit(const TimestampArray& array) override { - return Status::NotImplemented("timestamp"); - } - - Status Visit(const IntervalArray& array) override { - return Status::NotImplemented("interval"); - } - - Status Visit(const DecimalArray& array) override { - return Status::NotImplemented("decimal"); - } - Status Visit(const ListArray& array) override { WriteValidityField(array); WriteIntegerField("OFFSET", array.raw_value_offsets(), array.length() + 1); @@ -571,10 +545,6 @@ class JsonArrayWriter : public ArrayVisitor { return WriteChildren(type->children(), array.children()); } - Status Visit(const DictionaryArray& array) override { - return Status::NotImplemented("dictionary"); - } - private: const std::string& name_; const Array& array_; http://git-wip-us.apache.org/repos/asf/arrow/blob/ef3b6b34/cpp/src/arrow/schema-test.cc ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/schema-test.cc b/cpp/src/arrow/schema-test.cc deleted file mode 100644 index 4826199..0000000 --- a/cpp/src/arrow/schema-test.cc +++ /dev/null @@ -1,122 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include -#include -#include - -#include "gtest/gtest.h" - -#include "arrow/schema.h" -#include "arrow/type.h" - -using std::shared_ptr; -using std::vector; - -namespace arrow { - -TEST(TestField, Basics) { - Field f0("f0", int32()); - Field f0_nn("f0", int32(), false); - - ASSERT_EQ(f0.name, "f0"); - ASSERT_EQ(f0.type->ToString(), int32()->ToString()); - - ASSERT_TRUE(f0.nullable); - ASSERT_FALSE(f0_nn.nullable); -} - -TEST(TestField, Equals) { - Field f0("f0", int32()); - Field f0_nn("f0", int32(), false); - Field f0_other("f0", int32()); - - ASSERT_EQ(f0, f0_other); - ASSERT_NE(f0, f0_nn); -} - -class TestSchema : public ::testing::Test { - public: - void SetUp() {} -}; - -TEST_F(TestSchema, Basics) { - auto f0 = field("f0", int32()); - auto f1 = field("f1", uint8(), false); - auto f1_optional = field("f1", uint8()); - - auto f2 = field("f2", utf8()); - - vector> fields = {f0, f1, f2}; - auto schema = std::make_shared(fields); - - ASSERT_EQ(3, schema->num_fields()); - ASSERT_EQ(f0, schema->field(0)); - ASSERT_EQ(f1, schema->field(1)); - ASSERT_EQ(f2, schema->field(2)); - - auto schema2 = std::make_shared(fields); - - vector> fields3 = {f0, f1_optional, f2}; - auto schema3 = std::make_shared(fields3); - ASSERT_TRUE(schema->Equals(schema2)); - ASSERT_FALSE(schema->Equals(schema3)); - - ASSERT_TRUE(schema->Equals(*schema2.get())); - ASSERT_FALSE(schema->Equals(*schema3.get())); -} - -TEST_F(TestSchema, ToString) { - auto f0 = field("f0", int32()); - auto f1 = field("f1", uint8(), false); - auto f2 = field("f2", utf8()); - auto f3 = field("f3", list(int16())); - - vector> fields = {f0, f1, f2, f3}; - auto schema = std::make_shared(fields); - - std::string result = schema->ToString(); - std::string expected = R"(f0: int32 -f1: uint8 not null -f2: string -f3: list)"; - - ASSERT_EQ(expected, result); -} - -TEST_F(TestSchema, GetFieldByName) { - auto f0 = field("f0", int32()); - auto f1 = field("f1", uint8(), false); - auto f2 = field("f2", utf8()); - auto f3 = field("f3", list(int16())); - - vector> fields = {f0, f1, f2, f3}; - auto schema = std::make_shared(fields); - - std::shared_ptr result; - - result = schema->GetFieldByName("f1"); - ASSERT_TRUE(f1->Equals(result)); - - result = schema->GetFieldByName("f3"); - ASSERT_TRUE(f3->Equals(result)); - - result = schema->GetFieldByName("not-found"); - ASSERT_TRUE(result == nullptr); -} - -} // namespace arrow http://git-wip-us.apache.org/repos/asf/arrow/blob/ef3b6b34/cpp/src/arrow/type-test.cc ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/type-test.cc b/cpp/src/arrow/type-test.cc new file mode 100644 index 0000000..fe6c62a --- /dev/null +++ b/cpp/src/arrow/type-test.cc @@ -0,0 +1,146 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Unit tests for DataType (and subclasses), Field, and Schema + +#include +#include +#include + +#include "gtest/gtest.h" + +#include "arrow/schema.h" +#include "arrow/type.h" + +using std::shared_ptr; +using std::vector; + +namespace arrow { + +TEST(TestField, Basics) { + Field f0("f0", int32()); + Field f0_nn("f0", int32(), false); + + ASSERT_EQ(f0.name, "f0"); + ASSERT_EQ(f0.type->ToString(), int32()->ToString()); + + ASSERT_TRUE(f0.nullable); + ASSERT_FALSE(f0_nn.nullable); +} + +TEST(TestField, Equals) { + Field f0("f0", int32()); + Field f0_nn("f0", int32(), false); + Field f0_other("f0", int32()); + + ASSERT_TRUE(f0.Equals(f0_other)); + ASSERT_FALSE(f0.Equals(f0_nn)); +} + +class TestSchema : public ::testing::Test { + public: + void SetUp() {} +}; + +TEST_F(TestSchema, Basics) { + auto f0 = field("f0", int32()); + auto f1 = field("f1", uint8(), false); + auto f1_optional = field("f1", uint8()); + + auto f2 = field("f2", utf8()); + + vector> fields = {f0, f1, f2}; + auto schema = std::make_shared(fields); + + ASSERT_EQ(3, schema->num_fields()); + ASSERT_TRUE(f0->Equals(schema->field(0))); + ASSERT_TRUE(f1->Equals(schema->field(1))); + ASSERT_TRUE(f2->Equals(schema->field(2))); + + auto schema2 = std::make_shared(fields); + + vector> fields3 = {f0, f1_optional, f2}; + auto schema3 = std::make_shared(fields3); + ASSERT_TRUE(schema->Equals(schema2)); + ASSERT_FALSE(schema->Equals(schema3)); + + ASSERT_TRUE(schema->Equals(*schema2.get())); + ASSERT_FALSE(schema->Equals(*schema3.get())); +} + +TEST_F(TestSchema, ToString) { + auto f0 = field("f0", int32()); + auto f1 = field("f1", uint8(), false); + auto f2 = field("f2", utf8()); + auto f3 = field("f3", list(int16())); + + vector> fields = {f0, f1, f2, f3}; + auto schema = std::make_shared(fields); + + std::string result = schema->ToString(); + std::string expected = R"(f0: int32 +f1: uint8 not null +f2: string +f3: list)"; + + ASSERT_EQ(expected, result); +} + +TEST_F(TestSchema, GetFieldByName) { + auto f0 = field("f0", int32()); + auto f1 = field("f1", uint8(), false); + auto f2 = field("f2", utf8()); + auto f3 = field("f3", list(int16())); + + vector> fields = {f0, f1, f2, f3}; + auto schema = std::make_shared(fields); + + std::shared_ptr result; + + result = schema->GetFieldByName("f1"); + ASSERT_TRUE(f1->Equals(result)); + + result = schema->GetFieldByName("f3"); + ASSERT_TRUE(f3->Equals(result)); + + result = schema->GetFieldByName("not-found"); + ASSERT_TRUE(result == nullptr); +} + +TEST(TestTimeType, Equals) { + TimeType t1; + TimeType t2; + TimeType t3(TimeUnit::NANO); + TimeType t4(TimeUnit::NANO); + + ASSERT_TRUE(t1.Equals(t2)); + ASSERT_FALSE(t1.Equals(t3)); + ASSERT_TRUE(t3.Equals(t4)); +} + +TEST(TestTimestampType, Equals) { + TimestampType t1; + TimestampType t2; + TimestampType t3(TimeUnit::NANO); + TimestampType t4(TimeUnit::NANO); + + ASSERT_TRUE(t1.Equals(t2)); + ASSERT_FALSE(t1.Equals(t3)); + ASSERT_TRUE(t3.Equals(t4)); +} + +} // namespace arrow http://git-wip-us.apache.org/repos/asf/arrow/blob/ef3b6b34/cpp/src/arrow/type.cc ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index b97b465..23fa681 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -21,6 +21,7 @@ #include #include "arrow/array.h" +#include "arrow/compare.h" #include "arrow/status.h" #include "arrow/util/logging.h" @@ -46,16 +47,14 @@ std::string Field::ToString() const { DataType::~DataType() {} bool DataType::Equals(const DataType& other) const { - bool equals = - ((this == &other) || ((this->type == other.type) && - ((this->num_children() == other.num_children())))); - if (equals) { - for (int i = 0; i < num_children(); ++i) { - // TODO(emkornfield) limit recursion - if (!children_[i]->Equals(other.children_[i])) { return false; } - } - } - return equals; + bool are_equal = false; + Status error = TypeEquals(*this, other, &are_equal); + if (!error.ok()) { DCHECK(false) << "Types not comparable: " << error.ToString(); } + return are_equal; +} + +bool DataType::Equals(const std::shared_ptr& other) const { + return Equals(*other.get()); } std::string BooleanType::ToString() const { @@ -104,6 +103,15 @@ std::string DateType::ToString() const { return std::string("date"); } +// ---------------------------------------------------------------------- +// Union type + +UnionType::UnionType(const std::vector>& fields, + const std::vector& type_codes, UnionMode mode) + : DataType(Type::UNION), mode(mode), type_codes(type_codes) { + children_ = fields; +} + std::string UnionType::ToString() const { std::stringstream s; @@ -138,14 +146,6 @@ std::shared_ptr DictionaryType::dictionary() const { return dictionary_; } -bool DictionaryType::Equals(const DataType& other) const { - if (other.type != Type::DICTIONARY) { return false; } - const auto& other_dict = static_cast(other); - - return index_type_->Equals(other_dict.index_type_) && - dictionary_->Equals(other_dict.dictionary_); -} - std::string DictionaryType::ToString() const { std::stringstream ss; ss << "dictionarytype()->ToString() @@ -286,4 +286,37 @@ std::vector DecimalType::GetBufferLayout() const { return {}; } +// ---------------------------------------------------------------------- +// Default implementations of TypeVisitor methods + +#define TYPE_VISITOR_DEFAULT(TYPE_CLASS) \ + Status TypeVisitor::Visit(const TYPE_CLASS& type) { \ + return Status::NotImplemented(type.ToString()); \ + } + +TYPE_VISITOR_DEFAULT(NullType); +TYPE_VISITOR_DEFAULT(BooleanType); +TYPE_VISITOR_DEFAULT(Int8Type); +TYPE_VISITOR_DEFAULT(Int16Type); +TYPE_VISITOR_DEFAULT(Int32Type); +TYPE_VISITOR_DEFAULT(Int64Type); +TYPE_VISITOR_DEFAULT(UInt8Type); +TYPE_VISITOR_DEFAULT(UInt16Type); +TYPE_VISITOR_DEFAULT(UInt32Type); +TYPE_VISITOR_DEFAULT(UInt64Type); +TYPE_VISITOR_DEFAULT(HalfFloatType); +TYPE_VISITOR_DEFAULT(FloatType); +TYPE_VISITOR_DEFAULT(DoubleType); +TYPE_VISITOR_DEFAULT(StringType); +TYPE_VISITOR_DEFAULT(BinaryType); +TYPE_VISITOR_DEFAULT(DateType); +TYPE_VISITOR_DEFAULT(TimeType); +TYPE_VISITOR_DEFAULT(TimestampType); +TYPE_VISITOR_DEFAULT(IntervalType); +TYPE_VISITOR_DEFAULT(DecimalType); +TYPE_VISITOR_DEFAULT(ListType); +TYPE_VISITOR_DEFAULT(StructType); +TYPE_VISITOR_DEFAULT(UnionType); +TYPE_VISITOR_DEFAULT(DictionaryType); + } // namespace arrow http://git-wip-us.apache.org/repos/asf/arrow/blob/ef3b6b34/cpp/src/arrow/type.h ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index b15aa27..9a97fc3 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -112,34 +112,34 @@ class BufferDescr { int bit_width_; }; -class TypeVisitor { +class ARROW_EXPORT TypeVisitor { public: virtual ~TypeVisitor() = default; - virtual Status Visit(const NullType& type) = 0; - virtual Status Visit(const BooleanType& type) = 0; - virtual Status Visit(const Int8Type& type) = 0; - virtual Status Visit(const Int16Type& type) = 0; - virtual Status Visit(const Int32Type& type) = 0; - virtual Status Visit(const Int64Type& type) = 0; - virtual Status Visit(const UInt8Type& type) = 0; - virtual Status Visit(const UInt16Type& type) = 0; - virtual Status Visit(const UInt32Type& type) = 0; - virtual Status Visit(const UInt64Type& type) = 0; - virtual Status Visit(const HalfFloatType& type) = 0; - virtual Status Visit(const FloatType& type) = 0; - virtual Status Visit(const DoubleType& type) = 0; - virtual Status Visit(const StringType& type) = 0; - virtual Status Visit(const BinaryType& type) = 0; - virtual Status Visit(const DateType& type) = 0; - virtual Status Visit(const TimeType& type) = 0; - virtual Status Visit(const TimestampType& type) = 0; - virtual Status Visit(const IntervalType& type) = 0; - virtual Status Visit(const DecimalType& type) = 0; - virtual Status Visit(const ListType& type) = 0; - virtual Status Visit(const StructType& type) = 0; - virtual Status Visit(const UnionType& type) = 0; - virtual Status Visit(const DictionaryType& type) = 0; + virtual Status Visit(const NullType& type); + virtual Status Visit(const BooleanType& type); + virtual Status Visit(const Int8Type& type); + virtual Status Visit(const Int16Type& type); + virtual Status Visit(const Int32Type& type); + virtual Status Visit(const Int64Type& type); + virtual Status Visit(const UInt8Type& type); + virtual Status Visit(const UInt16Type& type); + virtual Status Visit(const UInt32Type& type); + virtual Status Visit(const UInt64Type& type); + virtual Status Visit(const HalfFloatType& type); + virtual Status Visit(const FloatType& type); + virtual Status Visit(const DoubleType& type); + virtual Status Visit(const StringType& type); + virtual Status Visit(const BinaryType& type); + virtual Status Visit(const DateType& type); + virtual Status Visit(const TimeType& type); + virtual Status Visit(const TimestampType& type); + virtual Status Visit(const IntervalType& type); + virtual Status Visit(const DecimalType& type); + virtual Status Visit(const ListType& type); + virtual Status Visit(const StructType& type); + virtual Status Visit(const UnionType& type); + virtual Status Visit(const DictionaryType& type); }; struct ARROW_EXPORT DataType { @@ -156,10 +156,7 @@ struct ARROW_EXPORT DataType { // Types that are logically convertable from one to another e.g. List // and Binary are NOT equal). virtual bool Equals(const DataType& other) const; - - bool Equals(const std::shared_ptr& other) const { - return Equals(*other.get()); - } + bool Equals(const std::shared_ptr& other) const; std::shared_ptr child(int i) const { return children_[i]; } @@ -211,8 +208,6 @@ struct ARROW_EXPORT Field { bool nullable = true) : name(name), type(type), nullable(nullable) {} - bool operator==(const Field& other) const { return this->Equals(other); } - bool operator!=(const Field& other) const { return !this->Equals(other); } bool Equals(const Field& other) const; bool Equals(const std::shared_ptr& other) const; @@ -411,10 +406,7 @@ struct ARROW_EXPORT UnionType : public DataType { static constexpr Type::type type_id = Type::UNION; UnionType(const std::vector>& fields, - const std::vector& type_codes, UnionMode mode = UnionMode::SPARSE) - : DataType(Type::UNION), mode(mode), type_codes(type_codes) { - children_ = fields; - } + const std::vector& type_codes, UnionMode mode = UnionMode::SPARSE); std::string ToString() const override; static std::string name() { return "union"; } @@ -523,8 +515,6 @@ class ARROW_EXPORT DictionaryType : public FixedWidthType { std::shared_ptr dictionary() const; - bool Equals(const DataType& other) const override; - Status Accept(TypeVisitor* visitor) const override; std::string ToString() const override;