arrow-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From w...@apache.org
Subject arrow git commit: ARROW-1136: [C++] Add null checks for invalid streams
Date Fri, 23 Jun 2017 18:26:28 GMT
Repository: arrow
Updated Branches:
  refs/heads/master 6768f5268 -> 8bf567e63


ARROW-1136: [C++] Add null checks for invalid streams

Author: Wes McKinney <wes.mckinney@twosigma.com>

Closes #770 from wesm/ARROW-1136 and squashes the following commits:

6ae5cd82 [Wes McKinney] Centralize null checking
bc3ec207 [Wes McKinney] Add null checks for invalid streams


Project: http://git-wip-us.apache.org/repos/asf/arrow/repo
Commit: http://git-wip-us.apache.org/repos/asf/arrow/commit/8bf567e6
Tree: http://git-wip-us.apache.org/repos/asf/arrow/tree/8bf567e6
Diff: http://git-wip-us.apache.org/repos/asf/arrow/diff/8bf567e6

Branch: refs/heads/master
Commit: 8bf567e636e6f8a7e779a6f89ad3169f3ffa9fba
Parents: 6768f52
Author: Wes McKinney <wes.mckinney@twosigma.com>
Authored: Fri Jun 23 14:26:12 2017 -0400
Committer: Wes McKinney <wes.mckinney@twosigma.com>
Committed: Fri Jun 23 14:26:12 2017 -0400

----------------------------------------------------------------------
 cpp/src/arrow/ipc/reader.cc      | 20 ++++++++++++++------
 python/pyarrow/tests/test_ipc.py | 10 ++++++++++
 2 files changed, 24 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/arrow/blob/8bf567e6/cpp/src/arrow/ipc/reader.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/ipc/reader.cc b/cpp/src/arrow/ipc/reader.cc
index 2b7b90f..7fef847 100644
--- a/cpp/src/arrow/ipc/reader.cc
+++ b/cpp/src/arrow/ipc/reader.cc
@@ -162,7 +162,7 @@ static inline FileBlock FileBlockFromFlatbuffer(const flatbuf::Block* block) {
   return FileBlock(block->offset(), block->metaDataLength(), block->bodyLength());
 }
 
-static inline std::string message_type_name(Message::Type type) {
+static inline std::string FormatMessageType(Message::Type type) {
   switch (type) {
     case Message::SCHEMA:
       return "schema";
@@ -188,14 +188,22 @@ class RecordBatchStreamReader::RecordBatchStreamReaderImpl {
     return ReadSchema();
   }
 
-  Status ReadNextMessage(Message::Type expected_type, std::shared_ptr<Message>* message) {
+  Status ReadNextMessage(Message::Type expected_type, bool allow_null,
+      std::shared_ptr<Message>* message) {
     RETURN_NOT_OK(ReadMessage(stream_.get(), message));
 
+    if (!(*message) && !allow_null) {
+      std::stringstream ss;
+      ss << "Expected " << FormatMessageType(expected_type)
+         << " message in stream, was null or length 0";
+      return Status::Invalid(ss.str());
+    }
+
     if ((*message) == nullptr) { return Status::OK(); }
 
     if ((*message)->type() != expected_type) {
       std::stringstream ss;
-      ss << "Message not expected type: " << message_type_name(expected_type)
+      ss << "Message not expected type: " << FormatMessageType(expected_type)
          << ", was: " << (*message)->type();
       return Status::IOError(ss.str());
     }
@@ -213,7 +221,7 @@ class RecordBatchStreamReader::RecordBatchStreamReaderImpl {
 
   Status ReadNextDictionary() {
     std::shared_ptr<Message> message;
-    RETURN_NOT_OK(ReadNextMessage(Message::DICTIONARY_BATCH, &message));
+    RETURN_NOT_OK(ReadNextMessage(Message::DICTIONARY_BATCH, false, &message));
 
     std::shared_ptr<Buffer> batch_body;
     RETURN_NOT_OK(ReadExact(message->body_length(), &batch_body))
@@ -227,7 +235,7 @@ class RecordBatchStreamReader::RecordBatchStreamReaderImpl {
 
   Status ReadSchema() {
     std::shared_ptr<Message> message;
-    RETURN_NOT_OK(ReadNextMessage(Message::SCHEMA, &message));
+    RETURN_NOT_OK(ReadNextMessage(Message::SCHEMA, false, &message));
 
     RETURN_NOT_OK(GetDictionaryTypes(message->header(), &dictionary_types_));
 
@@ -243,7 +251,7 @@ class RecordBatchStreamReader::RecordBatchStreamReaderImpl {
 
   Status GetNextRecordBatch(std::shared_ptr<RecordBatch>* batch) {
     std::shared_ptr<Message> message;
-    RETURN_NOT_OK(ReadNextMessage(Message::RECORD_BATCH, &message));
+    RETURN_NOT_OK(ReadNextMessage(Message::RECORD_BATCH, true, &message));
 
     if (message == nullptr) {
       // End of stream

http://git-wip-us.apache.org/repos/asf/arrow/blob/8bf567e6/python/pyarrow/tests/test_ipc.py
----------------------------------------------------------------------
diff --git a/python/pyarrow/tests/test_ipc.py b/python/pyarrow/tests/test_ipc.py
index eeea39a..47ef756 100644
--- a/python/pyarrow/tests/test_ipc.py
+++ b/python/pyarrow/tests/test_ipc.py
@@ -72,6 +72,11 @@ class TestFile(MessagingTest, unittest.TestCase):
     def _get_writer(self, sink, schema):
         return pa.RecordBatchFileWriter(sink, schema)
 
+    def test_empty_file(self):
+        buf = io.BytesIO(b'')
+        with pytest.raises(pa.ArrowInvalid):
+            pa.open_file(buf)
+
     def test_simple_roundtrip(self):
         batches = self.write_batches()
         file_contents = self._get_source()
@@ -101,6 +106,11 @@ class TestStream(MessagingTest, unittest.TestCase):
     def _get_writer(self, sink, schema):
         return pa.RecordBatchStreamWriter(sink, schema)
 
+    def test_empty_stream(self):
+        buf = io.BytesIO(b'')
+        with pytest.raises(pa.ArrowInvalid):
+            pa.open_stream(buf)
+
     def test_simple_roundtrip(self):
         batches = self.write_batches()
         file_contents = self._get_source()


Mime
View raw message