arrow-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From w...@apache.org
Subject arrow git commit: ARROW-451: [C++] Implement DataType::Equals as TypeVisitor. Add default implementations for TypeVisitor, ArrayVisitor methods
Date Sun, 26 Feb 2017 23:25:09 GMT
Repository: arrow
Updated Branches:
  refs/heads/master 8afe92c6c -> ef3b6b344


ARROW-451: [C++] Implement DataType::Equals as TypeVisitor. Add default implementations for
TypeVisitor, ArrayVisitor methods

This patch also resolves ARROW-568. Added tests for TimeType and TimestampType, whose `unit`
metadata was not being compared due to an oversight.

Author: Wes McKinney <wes.mckinney@twosigma.com>

Closes #350 from wesm/ARROW-451 and squashes the following commits:

97e75d8 [Wes McKinney] Export ArrayVisitor, TypeVisitor symbols
a3332be [Wes McKinney] Typo
635e74d [Wes McKinney] Implement DataType::Equals as TypeVisitor, compare child metadata.
Add default implementations for TypeVisitor, ArrayVisitor methods


Project: http://git-wip-us.apache.org/repos/asf/arrow/repo
Commit: http://git-wip-us.apache.org/repos/asf/arrow/commit/ef3b6b34
Tree: http://git-wip-us.apache.org/repos/asf/arrow/tree/ef3b6b34
Diff: http://git-wip-us.apache.org/repos/asf/arrow/diff/ef3b6b34

Branch: refs/heads/master
Commit: ef3b6b34482c36615af5064f474363126e755a18
Parents: 8afe92c
Author: Wes McKinney <wes.mckinney@twosigma.com>
Authored: Sun Feb 26 18:25:03 2017 -0500
Committer: Wes McKinney <wes.mckinney@twosigma.com>
Committed: Sun Feb 26 18:25:03 2017 -0500

----------------------------------------------------------------------
 cpp/src/arrow/CMakeLists.txt       |   2 +-
 cpp/src/arrow/array.cc             |  36 ++++++++
 cpp/src/arrow/array.h              |  50 +++++------
 cpp/src/arrow/compare.cc           | 108 +++++++++++++++++++++--
 cpp/src/arrow/compare.h            |   5 ++
 cpp/src/arrow/ipc/adapter.cc       |  20 -----
 cpp/src/arrow/ipc/json-internal.cc |  30 -------
 cpp/src/arrow/schema-test.cc       | 122 --------------------------
 cpp/src/arrow/type-test.cc         | 146 ++++++++++++++++++++++++++++++++
 cpp/src/arrow/type.cc              |  69 +++++++++++----
 cpp/src/arrow/type.h               |  64 ++++++--------
 11 files changed, 394 insertions(+), 258 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/arrow/blob/ef3b6b34/cpp/src/arrow/CMakeLists.txt
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt
index 824ced1..d1efa02 100644
--- a/cpp/src/arrow/CMakeLists.txt
+++ b/cpp/src/arrow/CMakeLists.txt
@@ -58,8 +58,8 @@ ADD_ARROW_TEST(buffer-test)
 ADD_ARROW_TEST(column-test)
 ADD_ARROW_TEST(memory_pool-test)
 ADD_ARROW_TEST(pretty_print-test)
-ADD_ARROW_TEST(schema-test)
 ADD_ARROW_TEST(status-test)
+ADD_ARROW_TEST(type-test)
 ADD_ARROW_TEST(table-test)
 
 ADD_ARROW_BENCHMARK(builder-benchmark)

http://git-wip-us.apache.org/repos/asf/arrow/blob/ef3b6b34/cpp/src/arrow/array.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/array.cc b/cpp/src/arrow/array.cc
index 81678e3..eb4c210 100644
--- a/cpp/src/arrow/array.cc
+++ b/cpp/src/arrow/array.cc
@@ -503,4 +503,40 @@ Status MakePrimitiveArray(const std::shared_ptr<DataType>& type, int32_t length,
 #endif
 }
 
+// ----------------------------------------------------------------------
+// Default implementations of ArrayVisitor methods
+
+// Defines an out-of-line ArrayVisitor::Visit overload for ARRAY_CLASS that
+// returns Status::NotImplemented carrying the array's type string, so a
+// concrete visitor only has to override the array classes it supports.
+#define ARRAY_VISITOR_DEFAULT(ARRAY_CLASS)                   \
+  Status ArrayVisitor::Visit(const ARRAY_CLASS& array) {     \
+    return Status::NotImplemented(array.type()->ToString()); \
+  }
+
+ARRAY_VISITOR_DEFAULT(NullArray);
+ARRAY_VISITOR_DEFAULT(BooleanArray);
+ARRAY_VISITOR_DEFAULT(Int8Array);
+ARRAY_VISITOR_DEFAULT(Int16Array);
+ARRAY_VISITOR_DEFAULT(Int32Array);
+ARRAY_VISITOR_DEFAULT(Int64Array);
+ARRAY_VISITOR_DEFAULT(UInt8Array);
+ARRAY_VISITOR_DEFAULT(UInt16Array);
+ARRAY_VISITOR_DEFAULT(UInt32Array);
+ARRAY_VISITOR_DEFAULT(UInt64Array);
+ARRAY_VISITOR_DEFAULT(HalfFloatArray);
+ARRAY_VISITOR_DEFAULT(FloatArray);
+ARRAY_VISITOR_DEFAULT(DoubleArray);
+ARRAY_VISITOR_DEFAULT(StringArray);
+ARRAY_VISITOR_DEFAULT(BinaryArray);
+ARRAY_VISITOR_DEFAULT(DateArray);
+ARRAY_VISITOR_DEFAULT(TimeArray);
+ARRAY_VISITOR_DEFAULT(TimestampArray);
+ARRAY_VISITOR_DEFAULT(IntervalArray);
+ARRAY_VISITOR_DEFAULT(ListArray);
+ARRAY_VISITOR_DEFAULT(StructArray);
+ARRAY_VISITOR_DEFAULT(UnionArray);
+ARRAY_VISITOR_DEFAULT(DictionaryArray);
+
+// DecimalArray does not go through the macro and reports a fixed "decimal"
+// message instead of the array's type string. NOTE(review): the reason is not
+// visible here (possibly DecimalArray is not fully usable yet) -- confirm.
+Status ArrayVisitor::Visit(const DecimalArray& array) {
+  return Status::NotImplemented("decimal");
+}
+
 }  // namespace arrow

http://git-wip-us.apache.org/repos/asf/arrow/blob/ef3b6b34/cpp/src/arrow/array.h
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/array.h b/cpp/src/arrow/array.h
index 9bb06af..8bb914e 100644
--- a/cpp/src/arrow/array.h
+++ b/cpp/src/arrow/array.h
@@ -38,34 +38,34 @@ class MemoryPool;
 class MutableBuffer;
 class Status;
 
-class ArrayVisitor {
+/// Visitor over the concrete Array subclasses, dispatched via Array::Accept.
+///
+/// Every overload has a default out-of-line implementation that returns
+/// Status::NotImplemented, so a subclass overrides only what it supports.
+class ARROW_EXPORT ArrayVisitor {
  public:
   virtual ~ArrayVisitor() = default;

-  virtual Status Visit(const NullArray& array) = 0;
-  virtual Status Visit(const BooleanArray& array) = 0;
-  virtual Status Visit(const Int8Array& array) = 0;
-  virtual Status Visit(const Int16Array& array) = 0;
-  virtual Status Visit(const Int32Array& array) = 0;
-  virtual Status Visit(const Int64Array& array) = 0;
-  virtual Status Visit(const UInt8Array& array) = 0;
-  virtual Status Visit(const UInt16Array& array) = 0;
-  virtual Status Visit(const UInt32Array& array) = 0;
-  virtual Status Visit(const UInt64Array& array) = 0;
-  virtual Status Visit(const HalfFloatArray& array) = 0;
-  virtual Status Visit(const FloatArray& array) = 0;
-  virtual Status Visit(const DoubleArray& array) = 0;
-  virtual Status Visit(const StringArray& array) = 0;
-  virtual Status Visit(const BinaryArray& array) = 0;
-  virtual Status Visit(const DateArray& array) = 0;
-  virtual Status Visit(const TimeArray& array) = 0;
-  virtual Status Visit(const TimestampArray& array) = 0;
-  virtual Status Visit(const IntervalArray& array) = 0;
-  virtual Status Visit(const DecimalArray& array) = 0;
-  virtual Status Visit(const ListArray& array) = 0;
-  virtual Status Visit(const StructArray& array) = 0;
-  virtual Status Visit(const UnionArray& array) = 0;
-  virtual Status Visit(const DictionaryArray& type) = 0;
+  virtual Status Visit(const NullArray& array);
+  virtual Status Visit(const BooleanArray& array);
+  virtual Status Visit(const Int8Array& array);
+  virtual Status Visit(const Int16Array& array);
+  virtual Status Visit(const Int32Array& array);
+  virtual Status Visit(const Int64Array& array);
+  virtual Status Visit(const UInt8Array& array);
+  virtual Status Visit(const UInt16Array& array);
+  virtual Status Visit(const UInt32Array& array);
+  virtual Status Visit(const UInt64Array& array);
+  virtual Status Visit(const HalfFloatArray& array);
+  virtual Status Visit(const FloatArray& array);
+  virtual Status Visit(const DoubleArray& array);
+  virtual Status Visit(const StringArray& array);
+  virtual Status Visit(const BinaryArray& array);
+  virtual Status Visit(const DateArray& array);
+  virtual Status Visit(const TimeArray& array);
+  virtual Status Visit(const TimestampArray& array);
+  virtual Status Visit(const IntervalArray& array);
+  virtual Status Visit(const DecimalArray& array);
+  virtual Status Visit(const ListArray& array);
+  virtual Status Visit(const StructArray& array);
+  virtual Status Visit(const UnionArray& array);
+  virtual Status Visit(const DictionaryArray& array);
 };
 
 /// Immutable data array with some logical type and some length.

http://git-wip-us.apache.org/repos/asf/arrow/blob/ef3b6b34/cpp/src/arrow/compare.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/compare.cc b/cpp/src/arrow/compare.cc
index 21fdb66..ff3c59f 100644
--- a/cpp/src/arrow/compare.cc
+++ b/cpp/src/arrow/compare.cc
@@ -301,9 +301,9 @@ class RangeEqualsVisitor : public ArrayVisitor {
   bool result_;
 };
 
-class EqualsVisitor : public RangeEqualsVisitor {
+class ArrayEqualsVisitor : public RangeEqualsVisitor {
  public:
-  explicit EqualsVisitor(const Array& right)
+  explicit ArrayEqualsVisitor(const Array& right)
       : RangeEqualsVisitor(right, 0, right.length(), 0) {}
 
   Status Visit(const NullArray& left) override { return Status::OK(); }
@@ -511,9 +511,9 @@ inline bool FloatingApproxEquals(
   return true;
 }
 
-class ApproxEqualsVisitor : public EqualsVisitor {
+class ApproxEqualsVisitor : public ArrayEqualsVisitor {
  public:
-  using EqualsVisitor::EqualsVisitor;
+  using ArrayEqualsVisitor::ArrayEqualsVisitor;
 
   Status Visit(const FloatArray& left) override {
     result_ =
@@ -549,7 +549,7 @@ Status ArrayEquals(const Array& left, const Array& right, bool* are_equal) {
   } else if (left.length() == 0) {
     *are_equal = true;
   } else {
-    EqualsVisitor visitor(right);
+    ArrayEqualsVisitor visitor(right);
     RETURN_NOT_OK(left.Accept(&visitor));
     *are_equal = visitor.result();
   }
@@ -588,4 +588,102 @@ Status ArrayApproxEquals(const Array& left, const Array& right, bool* are_equal)
   return Status::OK();
 }
 
+// ----------------------------------------------------------------------
+// Implement TypeEquals
+
+/// Visitor that compares the type-specific metadata of the visited ("left")
+/// type against `right_`.
+///
+/// Only dispatch here after checking that both types have the same Type id:
+/// each Visit overload static_casts `right_` to the left type's concrete
+/// class. Types with no overload here fall back to the TypeVisitor defaults,
+/// which return Status::NotImplemented.
+class TypeEqualsVisitor : public TypeVisitor {
+ public:
+  explicit TypeEqualsVisitor(const DataType& right) : right_(right), result_(false) {}
+
+  /// Sets result_ to true iff both types have the same number of children and
+  /// every pair of child fields compares equal.
+  Status VisitChildren(const DataType& left) {
+    if (left.num_children() != right_.num_children()) {
+      result_ = false;
+      return Status::OK();
+    }
+
+    for (int i = 0; i < left.num_children(); ++i) {
+      if (!left.child(i)->Equals(right_.child(i))) {
+        // Return right away; falling out of the loop would overwrite the
+        // mismatch with `result_ = true` below.
+        result_ = false;
+        return Status::OK();
+      }
+    }
+    result_ = true;
+    return Status::OK();
+  }
+
+  Status Visit(const TimeType& left) override {
+    const auto& right = static_cast<const TimeType&>(right_);
+    result_ = left.unit == right.unit;
+    return Status::OK();
+  }
+
+  Status Visit(const TimestampType& left) override {
+    const auto& right = static_cast<const TimestampType&>(right_);
+    result_ = left.unit == right.unit;
+    return Status::OK();
+  }
+
+  // Decimal types carry precision/scale metadata; without this overload the
+  // NotImplemented default would make every pair of decimals compare equal.
+  Status Visit(const DecimalType& left) override {
+    const auto& right = static_cast<const DecimalType&>(right_);
+    result_ = left.precision == right.precision && left.scale == right.scale;
+    return Status::OK();
+  }
+
+  Status Visit(const ListType& left) override { return VisitChildren(left); }
+
+  Status Visit(const StructType& left) override { return VisitChildren(left); }
+
+  Status Visit(const UnionType& left) override {
+    const auto& right = static_cast<const UnionType&>(right_);
+
+    // std::vector equality checks both the length and every type code, and
+    // does not copy the code vectors.
+    if (left.mode != right.mode || left.type_codes != right.type_codes) {
+      result_ = false;
+      return Status::OK();
+    }
+
+    // The member types are stored as child fields, so they must match too;
+    // otherwise unions with equal codes but different members compare equal.
+    return VisitChildren(left);
+  }
+
+  Status Visit(const DictionaryType& left) override {
+    const auto& right = static_cast<const DictionaryType&>(right_);
+    result_ = left.index_type()->Equals(right.index_type()) &&
+              left.dictionary()->Equals(right.dictionary());
+    return Status::OK();
+  }
+
+  bool result() const { return result_; }
+
+ protected:
+  const DataType& right_;
+  bool result_;
+};
+
+/// Compares `left` and `right` for exact type equality.
+///
+/// The same object is always equal, and differing Type ids are never equal.
+/// Otherwise TypeEqualsVisitor compares the type-specific metadata. A
+/// NotImplemented status from the visitor means it has no overload for that
+/// type, which is treated as "equal". Any other error is returned and
+/// *are_equal is left untouched.
+Status TypeEquals(const DataType& left, const DataType& right, bool* are_equal) {
+  // Both references point to the same type object
+  if (&left == &right) {
+    *are_equal = true;
+  } else if (left.type != right.type) {
+    *are_equal = false;
+  } else {
+    TypeEqualsVisitor visitor(right);
+    Status s = left.Accept(&visitor);
+
+    // We do not implement any type visitors where there is no additional
+    // metadata to compare.
+    if (s.IsNotImplemented()) {
+      // Not implemented means there is no additional metadata to compare
+      *are_equal = true;
+    } else if (!s.ok()) {
+      return s;
+    } else {
+      *are_equal = visitor.result();
+    }
+  }
+  return Status::OK();
+}
+
 }  // namespace arrow

http://git-wip-us.apache.org/repos/asf/arrow/blob/ef3b6b34/cpp/src/arrow/compare.h
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/compare.h b/cpp/src/arrow/compare.h
index 2093b65..6a71f9f 100644
--- a/cpp/src/arrow/compare.h
+++ b/cpp/src/arrow/compare.h
@@ -27,6 +27,7 @@
 namespace arrow {
 
 class Array;
+struct DataType;
 class Status;
 
 /// Returns true if the arrays are exactly equal
@@ -41,6 +42,10 @@ Status ARROW_EXPORT ArrayApproxEquals(
 Status ARROW_EXPORT ArrayRangeEquals(const Array& left, const Array& right,
     int32_t start_idx, int32_t end_idx, int32_t other_start_idx, bool* are_equal);
 
+/// Compares two types for exact equality, including type-specific metadata
+/// (e.g. time units, child fields, union type codes, dictionaries).
+///
+/// On success *are_equal holds the result. A non-OK Status is returned only
+/// when the comparison itself fails, in which case *are_equal is left unchanged.
+Status ARROW_EXPORT TypeEquals(
+    const DataType& left, const DataType& right, bool* are_equal);
+
 }  // namespace arrow
 
 #endif  // ARROW_COMPARE_H

http://git-wip-us.apache.org/repos/asf/arrow/blob/ef3b6b34/cpp/src/arrow/ipc/adapter.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/ipc/adapter.cc b/cpp/src/arrow/ipc/adapter.cc
index 08ac983..2be87a3 100644
--- a/cpp/src/arrow/ipc/adapter.cc
+++ b/cpp/src/arrow/ipc/adapter.cc
@@ -227,8 +227,6 @@ class RecordBatchWriter : public ArrayVisitor {
   }
 
  protected:
-  Status Visit(const NullArray& array) override { return Status::NotImplemented("null");
}
-
   template <typename ArrayType>
   Status VisitFixedWidth(const ArrayType& array) {
     std::shared_ptr<Buffer> data_buffer = array.data();
@@ -360,14 +358,6 @@ class RecordBatchWriter : public ArrayVisitor {
     return VisitFixedWidth<TimestampArray>(array);
   }
 
-  Status Visit(const IntervalArray& array) override {
-    return Status::NotImplemented("interval");
-  }
-
-  Status Visit(const DecimalArray& array) override {
-    return Status::NotImplemented("decimal");
-  }
-
   Status Visit(const ListArray& array) override {
     std::shared_ptr<Buffer> value_offsets;
     RETURN_NOT_OK(GetZeroBasedValueOffsets<ListArray>(array, &value_offsets));
@@ -653,8 +643,6 @@ class ArrayLoader : public TypeVisitor {
     return Status::OK();
   }
 
-  Status Visit(const NullType& type) override { return Status::NotImplemented("null");
}
-
   Status Visit(const BooleanType& type) override { return LoadPrimitive(type); }
 
   Status Visit(const Int8Type& type) override { return LoadPrimitive(type); }
@@ -689,14 +677,6 @@ class ArrayLoader : public TypeVisitor {
 
   Status Visit(const TimestampType& type) override { return LoadPrimitive(type); }
 
-  Status Visit(const IntervalType& type) override {
-    return Status::NotImplemented(type.ToString());
-  }
-
-  Status Visit(const DecimalType& type) override {
-    return Status::NotImplemented(type.ToString());
-  }
-
   Status Visit(const ListType& type) override {
     FieldMetadata field_meta;
     std::shared_ptr<Buffer> null_bitmap, offsets;

http://git-wip-us.apache.org/repos/asf/arrow/blob/ef3b6b34/cpp/src/arrow/ipc/json-internal.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/ipc/json-internal.cc b/cpp/src/arrow/ipc/json-internal.cc
index b9f97dd..6253cd6 100644
--- a/cpp/src/arrow/ipc/json-internal.cc
+++ b/cpp/src/arrow/ipc/json-internal.cc
@@ -316,8 +316,6 @@ class JsonSchemaWriter : public TypeVisitor {
     return WritePrimitive("interval", type);
   }
 
-  Status Visit(const DecimalType& type) override { return Status::NotImplemented("NYI");
}
-
   Status Visit(const ListType& type) override {
     WriteName("list", type);
     RETURN_NOT_OK(WriteChildren(type.children()));
@@ -339,14 +337,6 @@ class JsonSchemaWriter : public TypeVisitor {
     return Status::OK();
   }
 
-  Status Visit(const DictionaryType& type) override {
-    // WriteName("dictionary", type);
-    // WriteChildren(type.children());
-    // WriteBufferLayout(type.GetBufferLayout());
-    // return Status::OK();
-    return Status::NotImplemented("dictionary type");
-  }
-
  private:
   const Schema& schema_;
   RjWriter* writer_;
@@ -531,22 +521,6 @@ class JsonArrayWriter : public ArrayVisitor {
 
   Status Visit(const BinaryArray& array) override { return WriteVarBytes(array); }
 
-  Status Visit(const DateArray& array) override { return Status::NotImplemented("date");
}
-
-  Status Visit(const TimeArray& array) override { return Status::NotImplemented("time");
}
-
-  Status Visit(const TimestampArray& array) override {
-    return Status::NotImplemented("timestamp");
-  }
-
-  Status Visit(const IntervalArray& array) override {
-    return Status::NotImplemented("interval");
-  }
-
-  Status Visit(const DecimalArray& array) override {
-    return Status::NotImplemented("decimal");
-  }
-
   Status Visit(const ListArray& array) override {
     WriteValidityField(array);
     WriteIntegerField("OFFSET", array.raw_value_offsets(), array.length() + 1);
@@ -571,10 +545,6 @@ class JsonArrayWriter : public ArrayVisitor {
     return WriteChildren(type->children(), array.children());
   }
 
-  Status Visit(const DictionaryArray& array) override {
-    return Status::NotImplemented("dictionary");
-  }
-
  private:
   const std::string& name_;
   const Array& array_;

http://git-wip-us.apache.org/repos/asf/arrow/blob/ef3b6b34/cpp/src/arrow/schema-test.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/schema-test.cc b/cpp/src/arrow/schema-test.cc
deleted file mode 100644
index 4826199..0000000
--- a/cpp/src/arrow/schema-test.cc
+++ /dev/null
@@ -1,122 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "gtest/gtest.h"
-
-#include "arrow/schema.h"
-#include "arrow/type.h"
-
-using std::shared_ptr;
-using std::vector;
-
-namespace arrow {
-
-TEST(TestField, Basics) {
-  Field f0("f0", int32());
-  Field f0_nn("f0", int32(), false);
-
-  ASSERT_EQ(f0.name, "f0");
-  ASSERT_EQ(f0.type->ToString(), int32()->ToString());
-
-  ASSERT_TRUE(f0.nullable);
-  ASSERT_FALSE(f0_nn.nullable);
-}
-
-TEST(TestField, Equals) {
-  Field f0("f0", int32());
-  Field f0_nn("f0", int32(), false);
-  Field f0_other("f0", int32());
-
-  ASSERT_EQ(f0, f0_other);
-  ASSERT_NE(f0, f0_nn);
-}
-
-class TestSchema : public ::testing::Test {
- public:
-  void SetUp() {}
-};
-
-TEST_F(TestSchema, Basics) {
-  auto f0 = field("f0", int32());
-  auto f1 = field("f1", uint8(), false);
-  auto f1_optional = field("f1", uint8());
-
-  auto f2 = field("f2", utf8());
-
-  vector<shared_ptr<Field>> fields = {f0, f1, f2};
-  auto schema = std::make_shared<Schema>(fields);
-
-  ASSERT_EQ(3, schema->num_fields());
-  ASSERT_EQ(f0, schema->field(0));
-  ASSERT_EQ(f1, schema->field(1));
-  ASSERT_EQ(f2, schema->field(2));
-
-  auto schema2 = std::make_shared<Schema>(fields);
-
-  vector<shared_ptr<Field>> fields3 = {f0, f1_optional, f2};
-  auto schema3 = std::make_shared<Schema>(fields3);
-  ASSERT_TRUE(schema->Equals(schema2));
-  ASSERT_FALSE(schema->Equals(schema3));
-
-  ASSERT_TRUE(schema->Equals(*schema2.get()));
-  ASSERT_FALSE(schema->Equals(*schema3.get()));
-}
-
-TEST_F(TestSchema, ToString) {
-  auto f0 = field("f0", int32());
-  auto f1 = field("f1", uint8(), false);
-  auto f2 = field("f2", utf8());
-  auto f3 = field("f3", list(int16()));
-
-  vector<shared_ptr<Field>> fields = {f0, f1, f2, f3};
-  auto schema = std::make_shared<Schema>(fields);
-
-  std::string result = schema->ToString();
-  std::string expected = R"(f0: int32
-f1: uint8 not null
-f2: string
-f3: list<item: int16>)";
-
-  ASSERT_EQ(expected, result);
-}
-
-TEST_F(TestSchema, GetFieldByName) {
-  auto f0 = field("f0", int32());
-  auto f1 = field("f1", uint8(), false);
-  auto f2 = field("f2", utf8());
-  auto f3 = field("f3", list(int16()));
-
-  vector<shared_ptr<Field>> fields = {f0, f1, f2, f3};
-  auto schema = std::make_shared<Schema>(fields);
-
-  std::shared_ptr<Field> result;
-
-  result = schema->GetFieldByName("f1");
-  ASSERT_TRUE(f1->Equals(result));
-
-  result = schema->GetFieldByName("f3");
-  ASSERT_TRUE(f3->Equals(result));
-
-  result = schema->GetFieldByName("not-found");
-  ASSERT_TRUE(result == nullptr);
-}
-
-}  // namespace arrow

http://git-wip-us.apache.org/repos/asf/arrow/blob/ef3b6b34/cpp/src/arrow/type-test.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/type-test.cc b/cpp/src/arrow/type-test.cc
new file mode 100644
index 0000000..fe6c62a
--- /dev/null
+++ b/cpp/src/arrow/type-test.cc
@@ -0,0 +1,146 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Unit tests for DataType (and subclasses), Field, and Schema
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+
+#include "arrow/schema.h"
+#include "arrow/type.h"
+
+using std::shared_ptr;
+using std::vector;
+
+namespace arrow {
+
+// A Field exposes its name, type, and nullability; nullable defaults to true.
+TEST(TestField, Basics) {
+  Field f0("f0", int32());
+  Field f0_nn("f0", int32(), false);
+
+  ASSERT_EQ(f0.name, "f0");
+  ASSERT_EQ(f0.type->ToString(), int32()->ToString());
+
+  ASSERT_TRUE(f0.nullable);
+  ASSERT_FALSE(f0_nn.nullable);
+}
+
+// Field equality takes nullability into account, not just the name and type.
+TEST(TestField, Equals) {
+  Field f0("f0", int32());
+  Field f0_nn("f0", int32(), false);
+  Field f0_other("f0", int32());
+
+  ASSERT_TRUE(f0.Equals(f0_other));
+  ASSERT_FALSE(f0.Equals(f0_nn));
+}
+
+class TestSchema : public ::testing::Test {
+ public:
+  void SetUp() {}
+};
+
+// Schemas built from the same fields are equal. Changing one field's
+// nullability makes them unequal (checked via both Equals overloads).
+TEST_F(TestSchema, Basics) {
+  auto f0 = field("f0", int32());
+  auto f1 = field("f1", uint8(), false);
+  auto f1_optional = field("f1", uint8());
+
+  auto f2 = field("f2", utf8());
+
+  vector<shared_ptr<Field>> fields = {f0, f1, f2};
+  auto schema = std::make_shared<Schema>(fields);
+
+  ASSERT_EQ(3, schema->num_fields());
+  ASSERT_TRUE(f0->Equals(schema->field(0)));
+  ASSERT_TRUE(f1->Equals(schema->field(1)));
+  ASSERT_TRUE(f2->Equals(schema->field(2)));
+
+  auto schema2 = std::make_shared<Schema>(fields);
+
+  vector<shared_ptr<Field>> fields3 = {f0, f1_optional, f2};
+  auto schema3 = std::make_shared<Schema>(fields3);
+  ASSERT_TRUE(schema->Equals(schema2));
+  ASSERT_FALSE(schema->Equals(schema3));
+
+  ASSERT_TRUE(schema->Equals(*schema2.get()));
+  ASSERT_FALSE(schema->Equals(*schema3.get()));
+}
+
+// ToString renders one "name: type" line per field and appends " not null"
+// to non-nullable fields.
+TEST_F(TestSchema, ToString) {
+  auto f0 = field("f0", int32());
+  auto f1 = field("f1", uint8(), false);
+  auto f2 = field("f2", utf8());
+  auto f3 = field("f3", list(int16()));
+
+  vector<shared_ptr<Field>> fields = {f0, f1, f2, f3};
+  auto schema = std::make_shared<Schema>(fields);
+
+  std::string result = schema->ToString();
+  std::string expected = R"(f0: int32
+f1: uint8 not null
+f2: string
+f3: list<item: int16>)";
+
+  ASSERT_EQ(expected, result);
+}
+
+// GetFieldByName returns the matching field, or nullptr when no field matches.
+TEST_F(TestSchema, GetFieldByName) {
+  auto f0 = field("f0", int32());
+  auto f1 = field("f1", uint8(), false);
+  auto f2 = field("f2", utf8());
+  auto f3 = field("f3", list(int16()));
+
+  vector<shared_ptr<Field>> fields = {f0, f1, f2, f3};
+  auto schema = std::make_shared<Schema>(fields);
+
+  std::shared_ptr<Field> result;
+
+  result = schema->GetFieldByName("f1");
+  ASSERT_TRUE(f1->Equals(result));
+
+  result = schema->GetFieldByName("f3");
+  ASSERT_TRUE(f3->Equals(result));
+
+  result = schema->GetFieldByName("not-found");
+  ASSERT_TRUE(result == nullptr);
+}
+
+// ARROW-568: TimeType equality must compare the `unit`. t1/t2 use the default
+// unit, which these assertions expect to differ from NANO.
+TEST(TestTimeType, Equals) {
+  TimeType t1;
+  TimeType t2;
+  TimeType t3(TimeUnit::NANO);
+  TimeType t4(TimeUnit::NANO);
+
+  ASSERT_TRUE(t1.Equals(t2));
+  ASSERT_FALSE(t1.Equals(t3));
+  ASSERT_TRUE(t3.Equals(t4));
+}
+
+// ARROW-568: TimestampType equality must compare the `unit` as well.
+TEST(TestTimestampType, Equals) {
+  TimestampType t1;
+  TimestampType t2;
+  TimestampType t3(TimeUnit::NANO);
+  TimestampType t4(TimeUnit::NANO);
+
+  ASSERT_TRUE(t1.Equals(t2));
+  ASSERT_FALSE(t1.Equals(t3));
+  ASSERT_TRUE(t3.Equals(t4));
+}
+
+}  // namespace arrow

http://git-wip-us.apache.org/repos/asf/arrow/blob/ef3b6b34/cpp/src/arrow/type.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc
index b97b465..23fa681 100644
--- a/cpp/src/arrow/type.cc
+++ b/cpp/src/arrow/type.cc
@@ -21,6 +21,7 @@
 #include <string>
 
 #include "arrow/array.h"
+#include "arrow/compare.h"
 #include "arrow/status.h"
 #include "arrow/util/logging.h"
 
@@ -46,16 +47,14 @@ std::string Field::ToString() const {
 DataType::~DataType() {}
 
 bool DataType::Equals(const DataType& other) const {
-  bool equals =
-      ((this == &other) || ((this->type == other.type) &&
-                               ((this->num_children() == other.num_children()))));
-  if (equals) {
-    for (int i = 0; i < num_children(); ++i) {
-      // TODO(emkornfield) limit recursion
-      if (!children_[i]->Equals(other.children_[i])) { return false; }
-    }
-  }
-  return equals;
+  bool are_equal = false;
+  Status error = TypeEquals(*this, other, &are_equal);
+  if (!error.ok()) { DCHECK(false) << "Types not comparable: " << error.ToString();
}
+  return are_equal;
+}
+
+bool DataType::Equals(const std::shared_ptr<DataType>& other) const {
+  return Equals(*other.get());
 }
 
 std::string BooleanType::ToString() const {
@@ -104,6 +103,15 @@ std::string DateType::ToString() const {
   return std::string("date");
 }
 
+// ----------------------------------------------------------------------
+// Union type
+
+// `fields` become the union's child fields; `type_codes` and `mode` are stored
+// as given. children_ is inherited from DataType, so it is assigned in the
+// constructor body rather than in the member-initializer list.
+UnionType::UnionType(const std::vector<std::shared_ptr<Field>>& fields,
+    const std::vector<uint8_t>& type_codes, UnionMode mode)
+    : DataType(Type::UNION), mode(mode), type_codes(type_codes) {
+  children_ = fields;
+}
+
 std::string UnionType::ToString() const {
   std::stringstream s;
 
@@ -138,14 +146,6 @@ std::shared_ptr<Array> DictionaryType::dictionary() const {
   return dictionary_;
 }
 
-bool DictionaryType::Equals(const DataType& other) const {
-  if (other.type != Type::DICTIONARY) { return false; }
-  const auto& other_dict = static_cast<const DictionaryType&>(other);
-
-  return index_type_->Equals(other_dict.index_type_) &&
-         dictionary_->Equals(other_dict.dictionary_);
-}
-
 std::string DictionaryType::ToString() const {
   std::stringstream ss;
   ss << "dictionary<values=" << dictionary_->type()->ToString()
@@ -286,4 +286,37 @@ std::vector<BufferDescr> DecimalType::GetBufferLayout() const {
   return {};
 }
 
+// ----------------------------------------------------------------------
+// Default implementations of TypeVisitor methods
+
+// Defines an out-of-line TypeVisitor::Visit overload for TYPE_CLASS that
+// returns Status::NotImplemented carrying the type's ToString(), so a concrete
+// visitor only overrides the types it handles. TypeEquals (compare.cc) treats
+// this NotImplemented result as "no extra metadata to compare".
+#define TYPE_VISITOR_DEFAULT(TYPE_CLASS)              \
+  Status TypeVisitor::Visit(const TYPE_CLASS& type) { \
+    return Status::NotImplemented(type.ToString());   \
+  }
+
+TYPE_VISITOR_DEFAULT(NullType);
+TYPE_VISITOR_DEFAULT(BooleanType);
+TYPE_VISITOR_DEFAULT(Int8Type);
+TYPE_VISITOR_DEFAULT(Int16Type);
+TYPE_VISITOR_DEFAULT(Int32Type);
+TYPE_VISITOR_DEFAULT(Int64Type);
+TYPE_VISITOR_DEFAULT(UInt8Type);
+TYPE_VISITOR_DEFAULT(UInt16Type);
+TYPE_VISITOR_DEFAULT(UInt32Type);
+TYPE_VISITOR_DEFAULT(UInt64Type);
+TYPE_VISITOR_DEFAULT(HalfFloatType);
+TYPE_VISITOR_DEFAULT(FloatType);
+TYPE_VISITOR_DEFAULT(DoubleType);
+TYPE_VISITOR_DEFAULT(StringType);
+TYPE_VISITOR_DEFAULT(BinaryType);
+TYPE_VISITOR_DEFAULT(DateType);
+TYPE_VISITOR_DEFAULT(TimeType);
+TYPE_VISITOR_DEFAULT(TimestampType);
+TYPE_VISITOR_DEFAULT(IntervalType);
+TYPE_VISITOR_DEFAULT(DecimalType);
+TYPE_VISITOR_DEFAULT(ListType);
+TYPE_VISITOR_DEFAULT(StructType);
+TYPE_VISITOR_DEFAULT(UnionType);
+TYPE_VISITOR_DEFAULT(DictionaryType);
+
 }  // namespace arrow

http://git-wip-us.apache.org/repos/asf/arrow/blob/ef3b6b34/cpp/src/arrow/type.h
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h
index b15aa27..9a97fc3 100644
--- a/cpp/src/arrow/type.h
+++ b/cpp/src/arrow/type.h
@@ -112,34 +112,34 @@ class BufferDescr {
   int bit_width_;
 };
 
-class TypeVisitor {
+/// Visitor over the concrete DataType subclasses, dispatched via
+/// DataType::Accept.
+///
+/// Every overload has a default implementation (TYPE_VISITOR_DEFAULT in
+/// type.cc) that returns Status::NotImplemented, so a subclass overrides only
+/// the types it handles.
+class ARROW_EXPORT TypeVisitor {
  public:
   virtual ~TypeVisitor() = default;

-  virtual Status Visit(const NullType& type) = 0;
-  virtual Status Visit(const BooleanType& type) = 0;
-  virtual Status Visit(const Int8Type& type) = 0;
-  virtual Status Visit(const Int16Type& type) = 0;
-  virtual Status Visit(const Int32Type& type) = 0;
-  virtual Status Visit(const Int64Type& type) = 0;
-  virtual Status Visit(const UInt8Type& type) = 0;
-  virtual Status Visit(const UInt16Type& type) = 0;
-  virtual Status Visit(const UInt32Type& type) = 0;
-  virtual Status Visit(const UInt64Type& type) = 0;
-  virtual Status Visit(const HalfFloatType& type) = 0;
-  virtual Status Visit(const FloatType& type) = 0;
-  virtual Status Visit(const DoubleType& type) = 0;
-  virtual Status Visit(const StringType& type) = 0;
-  virtual Status Visit(const BinaryType& type) = 0;
-  virtual Status Visit(const DateType& type) = 0;
-  virtual Status Visit(const TimeType& type) = 0;
-  virtual Status Visit(const TimestampType& type) = 0;
-  virtual Status Visit(const IntervalType& type) = 0;
-  virtual Status Visit(const DecimalType& type) = 0;
-  virtual Status Visit(const ListType& type) = 0;
-  virtual Status Visit(const StructType& type) = 0;
-  virtual Status Visit(const UnionType& type) = 0;
-  virtual Status Visit(const DictionaryType& type) = 0;
+  virtual Status Visit(const NullType& type);
+  virtual Status Visit(const BooleanType& type);
+  virtual Status Visit(const Int8Type& type);
+  virtual Status Visit(const Int16Type& type);
+  virtual Status Visit(const Int32Type& type);
+  virtual Status Visit(const Int64Type& type);
+  virtual Status Visit(const UInt8Type& type);
+  virtual Status Visit(const UInt16Type& type);
+  virtual Status Visit(const UInt32Type& type);
+  virtual Status Visit(const UInt64Type& type);
+  virtual Status Visit(const HalfFloatType& type);
+  virtual Status Visit(const FloatType& type);
+  virtual Status Visit(const DoubleType& type);
+  virtual Status Visit(const StringType& type);
+  virtual Status Visit(const BinaryType& type);
+  virtual Status Visit(const DateType& type);
+  virtual Status Visit(const TimeType& type);
+  virtual Status Visit(const TimestampType& type);
+  virtual Status Visit(const IntervalType& type);
+  virtual Status Visit(const DecimalType& type);
+  virtual Status Visit(const ListType& type);
+  virtual Status Visit(const StructType& type);
+  virtual Status Visit(const UnionType& type);
+  virtual Status Visit(const DictionaryType& type);
 };
 
 struct ARROW_EXPORT DataType {
@@ -156,10 +156,7 @@ struct ARROW_EXPORT DataType {
   // Types that are logically convertable from one to another e.g. List<UInt8>
   // and Binary are NOT equal).
   virtual bool Equals(const DataType& other) const;
-
-  bool Equals(const std::shared_ptr<DataType>& other) const {
-    return Equals(*other.get());
-  }
+  bool Equals(const std::shared_ptr<DataType>& other) const;
 
   std::shared_ptr<Field> child(int i) const { return children_[i]; }
 
@@ -211,8 +208,6 @@ struct ARROW_EXPORT Field {
       bool nullable = true)
       : name(name), type(type), nullable(nullable) {}
 
-  bool operator==(const Field& other) const { return this->Equals(other); }
-  bool operator!=(const Field& other) const { return !this->Equals(other); }
   bool Equals(const Field& other) const;
   bool Equals(const std::shared_ptr<Field>& other) const;
 
@@ -411,10 +406,7 @@ struct ARROW_EXPORT UnionType : public DataType {
   static constexpr Type::type type_id = Type::UNION;
 
   UnionType(const std::vector<std::shared_ptr<Field>>& fields,
-      const std::vector<uint8_t>& type_codes, UnionMode mode = UnionMode::SPARSE)
-      : DataType(Type::UNION), mode(mode), type_codes(type_codes) {
-    children_ = fields;
-  }
+      const std::vector<uint8_t>& type_codes, UnionMode mode = UnionMode::SPARSE);
 
   std::string ToString() const override;
   static std::string name() { return "union"; }
@@ -523,8 +515,6 @@ class ARROW_EXPORT DictionaryType : public FixedWidthType {
 
   std::shared_ptr<Array> dictionary() const;
 
-  bool Equals(const DataType& other) const override;
-
   Status Accept(TypeVisitor* visitor) const override;
   std::string ToString() const override;
 


Mime
View raw message