arrow-commits mailing list archives

From w...@apache.org
Subject arrow git commit: ARROW-1418: [Python] Introduce SerializationContext to register custom serialization callbacks
Date Tue, 29 Aug 2017 18:32:44 GMT
Repository: arrow
Updated Branches:
  refs/heads/master 254078ef6 -> 441f96594


ARROW-1418: [Python] Introduce SerializationContext to register custom serialization callbacks

This removes the global serialize and deserialize callbacks and instead exposes the
functionality for users to register their own custom serialization callbacks.

Author: Philipp Moritz <pcmoritz@gmail.com>

Closes #1002 from pcmoritz/serialization-callbacks and squashes the following commits:

81f77796 [Philipp Moritz] Add SerializationContext to eliminate global type registry
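
For readers skimming the diff, a minimal usage sketch of the new API follows. It is not
part of the commit: the Point2D class and the io.BytesIO sink are illustrative assumptions
(the serialize_to docstring allows any NativeFile or file-like sink), and the call pattern
mirrors serialization_roundtrip in the updated tests.

    import io
    import pyarrow as pa

    # Hypothetical user-defined class, used only for this sketch
    class Point2D(object):
        def __init__(self, x, y):
            self.x = x
            self.y = y

    # Build a context and register the type; type ids are 20-byte strings,
    # as in the updated test_serialization.py
    context = pa.SerializationContext()
    context.register_type(Point2D, 20 * b"\x01")

    # serialize_to / deserialize_from now take the context instead of
    # relying on globally registered callbacks
    sink = io.BytesIO()
    pa.serialize_to(Point2D(1, 2), sink, context)
    sink.seek(0)
    restored = pa.deserialize_from(sink, None, context)
    assert restored.x == 1 and restored.y == 2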


Project: http://git-wip-us.apache.org/repos/asf/arrow/repo
Commit: http://git-wip-us.apache.org/repos/asf/arrow/commit/441f9659
Tree: http://git-wip-us.apache.org/repos/asf/arrow/tree/441f9659
Diff: http://git-wip-us.apache.org/repos/asf/arrow/diff/441f9659

Branch: refs/heads/master
Commit: 441f96594156a798c59a2e925ec6d5aec4541ddc
Parents: 254078e
Author: Philipp Moritz <pcmoritz@gmail.com>
Authored: Tue Aug 29 14:32:38 2017 -0400
Committer: Wes McKinney <wes.mckinney@twosigma.com>
Committed: Tue Aug 29 14:32:38 2017 -0400

----------------------------------------------------------------------
 ci/travis_script_python.sh                 |   4 +-
 cpp/src/arrow/python/arrow_to_python.cc    |  93 ++++++-----
 cpp/src/arrow/python/arrow_to_python.h     |   9 +-
 cpp/src/arrow/python/python_to_arrow.cc    | 115 +++++++-------
 cpp/src/arrow/python/python_to_arrow.h     |  18 +--
 python/doc/source/api.rst                  |   1 +
 python/pyarrow/__init__.py                 |   2 +-
 python/pyarrow/includes/libarrow.pxd       |   9 +-
 python/pyarrow/serialization.pxi           | 201 +++++++++++++-----------
 python/pyarrow/tests/test_serialization.py |  93 ++++++-----
 10 files changed, 285 insertions(+), 260 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/arrow/blob/441f9659/ci/travis_script_python.sh
----------------------------------------------------------------------
diff --git a/ci/travis_script_python.sh b/ci/travis_script_python.sh
index 66cd17d..b5ba136 100755
--- a/ci/travis_script_python.sh
+++ b/ci/travis_script_python.sh
@@ -95,10 +95,10 @@ python_version_tests() {
   conda install -y -q pip numpy pandas cython flake8
 
   # Fail fast on style checks
-  flake8 pyarrow
+  flake8 --count pyarrow
 
   # Check Cython files with some checks turned off
-  flake8 --config=.flake8.cython pyarrow
+  flake8 --count --config=.flake8.cython pyarrow
 
   # Build C++ libraries
   rebuild_arrow_libraries

http://git-wip-us.apache.org/repos/asf/arrow/blob/441f9659/cpp/src/arrow/python/arrow_to_python.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/python/arrow_to_python.cc b/cpp/src/arrow/python/arrow_to_python.cc
index dcb06f8..3ff1f11 100644
--- a/cpp/src/arrow/python/arrow_to_python.cc
+++ b/cpp/src/arrow/python/arrow_to_python.cc
@@ -27,40 +27,37 @@
 #include "arrow/python/common.h"
 #include "arrow/python/helpers.h"
 #include "arrow/python/numpy_convert.h"
+#include "arrow/python/python_to_arrow.h"
 #include "arrow/table.h"
 #include "arrow/util/logging.h"
 
-extern "C" {
-extern PyObject* pyarrow_serialize_callback;
-extern PyObject* pyarrow_deserialize_callback;
-}
-
 namespace arrow {
 namespace py {
 
-Status CallCustomCallback(PyObject* callback, PyObject* elem, PyObject** result);
+Status CallDeserializeCallback(PyObject* context, PyObject* value,
+                               PyObject** deserialized_object);
 
-Status DeserializeTuple(std::shared_ptr<Array> array, int64_t start_idx, int64_t stop_idx,
-                        PyObject* base,
+Status DeserializeTuple(PyObject* context, std::shared_ptr<Array> array,
+                        int64_t start_idx, int64_t stop_idx, PyObject* base,
                         const std::vector<std::shared_ptr<Tensor>>& tensors,
                         PyObject** out);
 
-Status DeserializeList(std::shared_ptr<Array> array, int64_t start_idx, int64_t stop_idx,
-                       PyObject* base,
+Status DeserializeList(PyObject* context, std::shared_ptr<Array> array, int64_t start_idx,
+                       int64_t stop_idx, PyObject* base,
                        const std::vector<std::shared_ptr<Tensor>>& tensors,
                        PyObject** out);
 
-Status DeserializeDict(std::shared_ptr<Array> array, int64_t start_idx, int64_t stop_idx,
-                       PyObject* base,
+Status DeserializeDict(PyObject* context, std::shared_ptr<Array> array, int64_t start_idx,
+                       int64_t stop_idx, PyObject* base,
                        const std::vector<std::shared_ptr<Tensor>>& tensors,
                        PyObject** out) {
   auto data = std::dynamic_pointer_cast<StructArray>(array);
   ScopedRef keys, vals;
   ScopedRef result(PyDict_New());
-  RETURN_NOT_OK(
-      DeserializeList(data->field(0), start_idx, stop_idx, base, tensors, keys.ref()));
-  RETURN_NOT_OK(
-      DeserializeList(data->field(1), start_idx, stop_idx, base, tensors, vals.ref()));
+  RETURN_NOT_OK(DeserializeList(context, data->field(0), start_idx, stop_idx, base,
+                                tensors, keys.ref()));
+  RETURN_NOT_OK(DeserializeList(context, data->field(1), start_idx, stop_idx, base,
+                                tensors, vals.ref()));
   for (int64_t i = start_idx; i < stop_idx; ++i) {
     // PyDict_SetItem behaves differently from PyList_SetItem and PyTuple_SetItem.
     // The latter two steal references whereas PyDict_SetItem does not. So we need
@@ -71,7 +68,7 @@ Status DeserializeDict(std::shared_ptr<Array> array, int64_t start_idx, int64_t
   }
   static PyObject* py_type = PyUnicode_FromString("_pytype_");
   if (PyDict_Contains(result.get(), py_type)) {
-    RETURN_NOT_OK(CallCustomCallback(pyarrow_deserialize_callback, result.get(), out));
+    RETURN_NOT_OK(CallDeserializeCallback(context, result.get(), out));
   } else {
     *out = result.release();
   }
@@ -93,7 +90,8 @@ Status DeserializeArray(std::shared_ptr<Array> array, int64_t offset, PyObject*
   return Status::OK();
 }
 
-Status GetValue(std::shared_ptr<Array> arr, int64_t index, int32_t type, PyObject* base,
+Status GetValue(PyObject* context, std::shared_ptr<Array> arr, int64_t index,
+                int32_t type, PyObject* base,
                 const std::vector<std::shared_ptr<Tensor>>& tensors, PyObject** result) {
   switch (arr->type()->id()) {
     case Type::BOOL:
@@ -130,13 +128,13 @@ Status GetValue(std::shared_ptr<Array> arr, int64_t index, int32_t type, PyObjec
       auto s = std::static_pointer_cast<StructArray>(arr);
       auto l = std::static_pointer_cast<ListArray>(s->field(0));
       if (s->type()->child(0)->name() == "list") {
-        return DeserializeList(l->values(), l->value_offset(index),
+        return DeserializeList(context, l->values(), l->value_offset(index),
                                l->value_offset(index + 1), base, tensors, result);
       } else if (s->type()->child(0)->name() == "tuple") {
-        return DeserializeTuple(l->values(), l->value_offset(index),
+        return DeserializeTuple(context, l->values(), l->value_offset(index),
                                 l->value_offset(index + 1), base, tensors, result);
       } else if (s->type()->child(0)->name() == "dict") {
-        return DeserializeDict(l->values(), l->value_offset(index),
+        return DeserializeDict(context, l->values(), l->value_offset(index),
                                l->value_offset(index + 1), base, tensors, result);
       } else {
         DCHECK(false) << "unexpected StructArray type " << s->type()->child(0)->name();
@@ -153,37 +151,37 @@ Status GetValue(std::shared_ptr<Array> arr, int64_t index, int32_t type, PyObjec
   return Status::OK();
 }
 
-#define DESERIALIZE_SEQUENCE(CREATE_FN, SET_ITEM_FN)                        \
-  auto data = std::dynamic_pointer_cast<UnionArray>(array);                 \
-  int64_t size = array->length();                                           \
-  ScopedRef result(CREATE_FN(stop_idx - start_idx));                        \
-  auto types = std::make_shared<Int8Array>(size, data->type_ids());         \
-  auto offsets = std::make_shared<Int32Array>(size, data->value_offsets()); \
-  for (int64_t i = start_idx; i < stop_idx; ++i) {                          \
-    if (data->IsNull(i)) {                                                  \
-      Py_INCREF(Py_None);                                                   \
-      SET_ITEM_FN(result.get(), i - start_idx, Py_None);                    \
-    } else {                                                                \
-      int64_t offset = offsets->Value(i);                                   \
-      int8_t type = types->Value(i);                                        \
-      std::shared_ptr<Array> arr = data->child(type);                       \
-      PyObject* value;                                                      \
-      RETURN_NOT_OK(GetValue(arr, offset, type, base, tensors, &value));    \
-      SET_ITEM_FN(result.get(), i - start_idx, value);                      \
-    }                                                                       \
-  }                                                                         \
-  *out = result.release();                                                  \
+#define DESERIALIZE_SEQUENCE(CREATE_FN, SET_ITEM_FN)                              \
+  auto data = std::dynamic_pointer_cast<UnionArray>(array);                       \
+  int64_t size = array->length();                                                 \
+  ScopedRef result(CREATE_FN(stop_idx - start_idx));                              \
+  auto types = std::make_shared<Int8Array>(size, data->type_ids());               \
+  auto offsets = std::make_shared<Int32Array>(size, data->value_offsets());       \
+  for (int64_t i = start_idx; i < stop_idx; ++i) {                                \
+    if (data->IsNull(i)) {                                                        \
+      Py_INCREF(Py_None);                                                         \
+      SET_ITEM_FN(result.get(), i - start_idx, Py_None);                          \
+    } else {                                                                      \
+      int64_t offset = offsets->Value(i);                                         \
+      int8_t type = types->Value(i);                                              \
+      std::shared_ptr<Array> arr = data->child(type);                             \
+      PyObject* value;                                                            \
+      RETURN_NOT_OK(GetValue(context, arr, offset, type, base, tensors, &value)); \
+      SET_ITEM_FN(result.get(), i - start_idx, value);                            \
+    }                                                                             \
+  }                                                                               \
+  *out = result.release();                                                        \
   return Status::OK();
 
-Status DeserializeList(std::shared_ptr<Array> array, int64_t start_idx, int64_t stop_idx,
-                       PyObject* base,
+Status DeserializeList(PyObject* context, std::shared_ptr<Array> array, int64_t start_idx,
+                       int64_t stop_idx, PyObject* base,
                        const std::vector<std::shared_ptr<Tensor>>& tensors,
                        PyObject** out) {
   DESERIALIZE_SEQUENCE(PyList_New, PyList_SET_ITEM)
 }
 
-Status DeserializeTuple(std::shared_ptr<Array> array, int64_t start_idx, int64_t stop_idx,
-                        PyObject* base,
+Status DeserializeTuple(PyObject* context, std::shared_ptr<Array> array,
+                        int64_t start_idx, int64_t stop_idx, PyObject* base,
                         const std::vector<std::shared_ptr<Tensor>>& tensors,
                         PyObject** out) {
   DESERIALIZE_SEQUENCE(PyTuple_New, PyTuple_SET_ITEM)
@@ -212,9 +210,10 @@ Status ReadSerializedObject(io::RandomAccessFile* src, SerializedPyObject* out)
   return Status::OK();
 }
 
-Status DeserializeObject(const SerializedPyObject& obj, PyObject* base, PyObject** out) {
+Status DeserializeObject(PyObject* context, const SerializedPyObject& obj, PyObject* base,
+                         PyObject** out) {
   PyAcquireGIL lock;
-  return DeserializeList(obj.batch->column(0), 0, obj.batch->num_rows(), base,
+  return DeserializeList(context, obj.batch->column(0), 0, obj.batch->num_rows(), base,
                          obj.tensors, out);
 }
 

http://git-wip-us.apache.org/repos/asf/arrow/blob/441f9659/cpp/src/arrow/python/arrow_to_python.h
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/python/arrow_to_python.h b/cpp/src/arrow/python/arrow_to_python.h
index e187d59..3650c05 100644
--- a/cpp/src/arrow/python/arrow_to_python.h
+++ b/cpp/src/arrow/python/arrow_to_python.h
@@ -49,6 +49,11 @@ ARROW_EXPORT
 Status ReadSerializedObject(io::RandomAccessFile* src, SerializedPyObject* out);
 
 /// \brief Reconstruct Python object from Arrow-serialized representation
+/// \param[in] context Serialization context which contains custom serialization
+/// and deserialization callbacks. Can be any Python object with a
+/// _serialize_callback method for serialization and a _deserialize_callback
+/// method for deserialization. If context is None, no custom serialization
+/// will be attempted.
 /// \param[in] object
 /// \param[in] base a Python object holding the underlying data that any NumPy
 /// arrays will reference, to avoid premature deallocation
@@ -56,8 +61,8 @@ Status ReadSerializedObject(io::RandomAccessFile* src, SerializedPyObject* out);
 /// \return Status
 /// This acquires the GIL
 ARROW_EXPORT
-Status DeserializeObject(const SerializedPyObject& object, PyObject* base,
-                         PyObject** out);
+Status DeserializeObject(PyObject* context, const SerializedPyObject& object,
+                         PyObject* base, PyObject** out);
 
 }  // namespace py
 }  // namespace arrow

http://git-wip-us.apache.org/repos/asf/arrow/blob/441f9659/cpp/src/arrow/python/python_to_arrow.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/python/python_to_arrow.cc b/cpp/src/arrow/python/python_to_arrow.cc
index e00f194..9ae66dc 100644
--- a/cpp/src/arrow/python/python_to_arrow.cc
+++ b/cpp/src/arrow/python/python_to_arrow.cc
@@ -41,11 +41,6 @@
 
 constexpr int32_t kMaxRecursionDepth = 100;
 
-extern "C" {
-PyObject* pyarrow_serialize_callback = NULL;
-PyObject* pyarrow_deserialize_callback = NULL;
-}
-
 namespace arrow {
 namespace py {
 
@@ -323,9 +318,10 @@ class DictBuilder {
   SequenceBuilder vals_;
 };
 
-Status CallCustomCallback(PyObject* callback, PyObject* elem, PyObject** result) {
+Status CallCustomCallback(PyObject* context, PyObject* method_name, PyObject* elem,
+                          PyObject** result) {
   *result = NULL;
-  if (!callback) {
+  if (context == Py_None) {
     std::stringstream ss;
     ScopedRef repr(PyObject_Repr(elem));
     RETURN_IF_PYERROR();
@@ -334,36 +330,38 @@ Status CallCustomCallback(PyObject* callback, PyObject* elem, PyObject** result)
        << ": handler not registered";
     return Status::NotImplemented(ss.str());
   } else {
-    ScopedRef arglist(Py_BuildValue("(O)", elem));
-    *result = PyObject_CallObject(callback, arglist.get());
+    *result = PyObject_CallMethodObjArgs(context, method_name, elem, NULL);
     RETURN_IF_PYERROR();
   }
   return Status::OK();
 }
 
-void set_serialization_callbacks(PyObject* serialize_callback,
-                                 PyObject* deserialize_callback) {
-  pyarrow_serialize_callback = serialize_callback;
-  pyarrow_deserialize_callback = deserialize_callback;
-}
-
-Status CallCustomSerializationCallback(PyObject* elem, PyObject** serialized_object) {
-  RETURN_NOT_OK(CallCustomCallback(pyarrow_serialize_callback, elem, serialized_object));
+Status CallSerializeCallback(PyObject* context, PyObject* value,
+                             PyObject** serialized_object) {
+  ScopedRef method_name(PyUnicode_FromString("_serialize_callback"));
+  RETURN_NOT_OK(CallCustomCallback(context, method_name.get(), value, serialized_object));
   if (!PyDict_Check(*serialized_object)) {
     return Status::TypeError("serialization callback must return a valid dictionary");
   }
   return Status::OK();
 }
 
-Status SerializeDict(std::vector<PyObject*> dicts, int32_t recursion_depth,
-                     std::shared_ptr<Array>* out, std::vector<PyObject*>* tensors_out);
+Status CallDeserializeCallback(PyObject* context, PyObject* value,
+                               PyObject** deserialized_object) {
+  ScopedRef method_name(PyUnicode_FromString("_deserialize_callback"));
+  return CallCustomCallback(context, method_name.get(), value, deserialized_object);
+}
+
+Status SerializeDict(PyObject* context, std::vector<PyObject*> dicts,
+                     int32_t recursion_depth, std::shared_ptr<Array>* out,
+                     std::vector<PyObject*>* tensors_out);
 
-Status SerializeArray(PyArrayObject* array, SequenceBuilder* builder,
+Status SerializeArray(PyObject* context, PyArrayObject* array, SequenceBuilder* builder,
                       std::vector<PyObject*>* subdicts,
                       std::vector<PyObject*>* tensors_out);
 
-Status SerializeSequences(std::vector<PyObject*> sequences, int32_t recursion_depth,
-                          std::shared_ptr<Array>* out,
+Status SerializeSequences(PyObject* context, std::vector<PyObject*> sequences,
+                          int32_t recursion_depth, std::shared_ptr<Array>* out,
                           std::vector<PyObject*>* tensors_out);
 
 Status AppendScalar(PyObject* obj, SequenceBuilder* builder) {
@@ -405,9 +403,9 @@ Status AppendScalar(PyObject* obj, SequenceBuilder* builder) {
   return builder->AppendInt64(value);
 }
 
-Status Append(PyObject* elem, SequenceBuilder* builder, std::vector<PyObject*>* sublists,
-              std::vector<PyObject*>* subtuples, std::vector<PyObject*>* subdicts,
-              std::vector<PyObject*>* tensors_out) {
+Status Append(PyObject* context, PyObject* elem, SequenceBuilder* builder,
+              std::vector<PyObject*>* sublists, std::vector<PyObject*>* subtuples,
+              std::vector<PyObject*>* subdicts, std::vector<PyObject*>* tensors_out) {
   // The bool case must precede the int case (PyInt_Check passes for bools)
   if (PyBool_Check(elem)) {
     RETURN_NOT_OK(builder->AppendBool(elem == Py_True));
@@ -422,7 +420,7 @@ Status Append(PyObject* elem, SequenceBuilder* builder, std::vector<PyObject*>*
       // Attempt to serialize the object using the custom callback.
       PyObject* serialized_object;
       // The reference count of serialized_object will be decremented in SerializeDict
-      RETURN_NOT_OK(CallCustomSerializationCallback(elem, &serialized_object));
+      RETURN_NOT_OK(CallSerializeCallback(context, elem, &serialized_object));
       RETURN_NOT_OK(builder->AppendDict(PyDict_Size(serialized_object)));
       subdicts->push_back(serialized_object);
     }
@@ -462,7 +460,7 @@ Status Append(PyObject* elem, SequenceBuilder* builder, std::vector<PyObject*>*
   } else if (PyArray_IsScalar(elem, Generic)) {
     RETURN_NOT_OK(AppendScalar(elem, builder));
   } else if (PyArray_Check(elem)) {
-    RETURN_NOT_OK(SerializeArray(reinterpret_cast<PyArrayObject*>(elem), builder,
+    RETURN_NOT_OK(SerializeArray(context, reinterpret_cast<PyArrayObject*>(elem), builder,
                                  subdicts, tensors_out));
   } else if (elem == Py_None) {
     RETURN_NOT_OK(builder->AppendNone());
@@ -470,14 +468,14 @@ Status Append(PyObject* elem, SequenceBuilder* builder, std::vector<PyObject*>*
     // Attempt to serialize the object using the custom callback.
     PyObject* serialized_object;
     // The reference count of serialized_object will be decremented in SerializeDict
-    RETURN_NOT_OK(CallCustomSerializationCallback(elem, &serialized_object));
+    RETURN_NOT_OK(CallSerializeCallback(context, elem, &serialized_object));
     RETURN_NOT_OK(builder->AppendDict(PyDict_Size(serialized_object)));
     subdicts->push_back(serialized_object);
   }
   return Status::OK();
 }
 
-Status SerializeArray(PyArrayObject* array, SequenceBuilder* builder,
+Status SerializeArray(PyObject* context, PyArrayObject* array, SequenceBuilder* builder,
                       std::vector<PyObject*>* subdicts,
                       std::vector<PyObject*>* tensors_out) {
   int dtype = PyArray_TYPE(array);
@@ -499,8 +497,8 @@ Status SerializeArray(PyArrayObject* array, SequenceBuilder* builder,
     default: {
       PyObject* serialized_object;
       // The reference count of serialized_object will be decremented in SerializeDict
-      RETURN_NOT_OK(CallCustomSerializationCallback(reinterpret_cast<PyObject*>(array),
-                                                    &serialized_object));
+      RETURN_NOT_OK(CallSerializeCallback(context, reinterpret_cast<PyObject*>(array),
+                                          &serialized_object));
       RETURN_NOT_OK(builder->AppendDict(PyDict_Size(serialized_object)));
       subdicts->push_back(serialized_object);
     }
@@ -508,8 +506,8 @@ Status SerializeArray(PyArrayObject* array, SequenceBuilder* builder,
   return Status::OK();
 }
 
-Status SerializeSequences(std::vector<PyObject*> sequences, int32_t recursion_depth,
-                          std::shared_ptr<Array>* out,
+Status SerializeSequences(PyObject* context, std::vector<PyObject*> sequences,
+                          int32_t recursion_depth, std::shared_ptr<Array>* out,
                           std::vector<PyObject*>* tensors_out) {
   DCHECK(out);
   if (recursion_depth >= kMaxRecursionDepth) {
@@ -524,28 +522,31 @@ Status SerializeSequences(std::vector<PyObject*> sequences, int32_t recursion_de
     RETURN_IF_PYERROR();
     ScopedRef item;
     while (item.reset(PyIter_Next(iterator.get())), item.get()) {
-      RETURN_NOT_OK(
-          Append(item.get(), &builder, &sublists, &subtuples, &subdicts, tensors_out));
+      RETURN_NOT_OK(Append(context, item.get(), &builder, &sublists, &subtuples,
+                           &subdicts, tensors_out));
     }
   }
   std::shared_ptr<Array> list;
   if (sublists.size() > 0) {
-    RETURN_NOT_OK(SerializeSequences(sublists, recursion_depth + 1, &list, tensors_out));
+    RETURN_NOT_OK(
+        SerializeSequences(context, sublists, recursion_depth + 1, &list, tensors_out));
   }
   std::shared_ptr<Array> tuple;
   if (subtuples.size() > 0) {
     RETURN_NOT_OK(
-        SerializeSequences(subtuples, recursion_depth + 1, &tuple, tensors_out));
+        SerializeSequences(context, subtuples, recursion_depth + 1, &tuple, tensors_out));
   }
   std::shared_ptr<Array> dict;
   if (subdicts.size() > 0) {
-    RETURN_NOT_OK(SerializeDict(subdicts, recursion_depth + 1, &dict, tensors_out));
+    RETURN_NOT_OK(
+        SerializeDict(context, subdicts, recursion_depth + 1, &dict, tensors_out));
   }
   return builder.Finish(list.get(), tuple.get(), dict.get(), out);
 }
 
-Status SerializeDict(std::vector<PyObject*> dicts, int32_t recursion_depth,
-                     std::shared_ptr<Array>* out, std::vector<PyObject*>* tensors_out) {
+Status SerializeDict(PyObject* context, std::vector<PyObject*> dicts,
+                     int32_t recursion_depth, std::shared_ptr<Array>* out,
+                     std::vector<PyObject*>* tensors_out) {
   DictBuilder result;
   if (recursion_depth >= kMaxRecursionDepth) {
     return Status::NotImplemented(
@@ -557,37 +558,37 @@ Status SerializeDict(std::vector<PyObject*> dicts, int32_t recursion_depth,
     PyObject *key, *value;
     Py_ssize_t pos = 0;
     while (PyDict_Next(dict, &pos, &key, &value)) {
-      RETURN_NOT_OK(
-          Append(key, &result.keys(), &dummy, &key_tuples, &key_dicts, tensors_out));
-      DCHECK_EQ(dummy.size(), 0);
-      RETURN_NOT_OK(Append(value, &result.vals(), &val_lists, &val_tuples, &val_dicts,
+      RETURN_NOT_OK(Append(context, key, &result.keys(), &dummy, &key_tuples, &key_dicts,
                            tensors_out));
+      DCHECK_EQ(dummy.size(), 0);
+      RETURN_NOT_OK(Append(context, value, &result.vals(), &val_lists, &val_tuples,
+                           &val_dicts, tensors_out));
     }
   }
   std::shared_ptr<Array> key_tuples_arr;
   if (key_tuples.size() > 0) {
-    RETURN_NOT_OK(SerializeSequences(key_tuples, recursion_depth + 1, &key_tuples_arr,
-                                     tensors_out));
+    RETURN_NOT_OK(SerializeSequences(context, key_tuples, recursion_depth + 1,
+                                     &key_tuples_arr, tensors_out));
   }
   std::shared_ptr<Array> key_dicts_arr;
   if (key_dicts.size() > 0) {
-    RETURN_NOT_OK(
-        SerializeDict(key_dicts, recursion_depth + 1, &key_dicts_arr, tensors_out));
+    RETURN_NOT_OK(SerializeDict(context, key_dicts, recursion_depth + 1, &key_dicts_arr,
+                                tensors_out));
   }
   std::shared_ptr<Array> val_list_arr;
   if (val_lists.size() > 0) {
-    RETURN_NOT_OK(
-        SerializeSequences(val_lists, recursion_depth + 1, &val_list_arr, tensors_out));
+    RETURN_NOT_OK(SerializeSequences(context, val_lists, recursion_depth + 1,
+                                     &val_list_arr, tensors_out));
   }
   std::shared_ptr<Array> val_tuples_arr;
   if (val_tuples.size() > 0) {
-    RETURN_NOT_OK(SerializeSequences(val_tuples, recursion_depth + 1, &val_tuples_arr,
-                                     tensors_out));
+    RETURN_NOT_OK(SerializeSequences(context, val_tuples, recursion_depth + 1,
+                                     &val_tuples_arr, tensors_out));
   }
   std::shared_ptr<Array> val_dict_arr;
   if (val_dicts.size() > 0) {
-    RETURN_NOT_OK(
-        SerializeDict(val_dicts, recursion_depth + 1, &val_dict_arr, tensors_out));
+    RETURN_NOT_OK(SerializeDict(context, val_dicts, recursion_depth + 1, &val_dict_arr,
+                                tensors_out));
   }
   RETURN_NOT_OK(result.Finish(key_tuples_arr.get(), key_dicts_arr.get(),
                               val_list_arr.get(), val_tuples_arr.get(),
@@ -601,7 +602,7 @@ Status SerializeDict(std::vector<PyObject*> dicts, int32_t recursion_depth,
     if (PyDict_Contains(dict, py_type)) {
       // If the dictionary contains the key "_pytype_", then the user has to
       // have registered a callback.
-      if (pyarrow_serialize_callback == nullptr) {
+      if (context == Py_None) {
         return Status::Invalid("No serialization callback set");
       }
       Py_XDECREF(dict);
@@ -617,12 +618,12 @@ std::shared_ptr<RecordBatch> MakeBatch(std::shared_ptr<Array> data) {
   return std::shared_ptr<RecordBatch>(new RecordBatch(schema, data->length(), {data}));
 }
 
-Status SerializeObject(PyObject* sequence, SerializedPyObject* out) {
+Status SerializeObject(PyObject* context, PyObject* sequence, SerializedPyObject* out) {
   PyAcquireGIL lock;
   std::vector<PyObject*> sequences = {sequence};
   std::shared_ptr<Array> array;
   std::vector<PyObject*> py_tensors;
-  RETURN_NOT_OK(SerializeSequences(sequences, 0, &array, &py_tensors));
+  RETURN_NOT_OK(SerializeSequences(context, sequences, 0, &array, &py_tensors));
   out->batch = MakeBatch(array);
   for (const auto& py_tensor : py_tensors) {
     std::shared_ptr<Tensor> arrow_tensor;

http://git-wip-us.apache.org/repos/asf/arrow/blob/441f9659/cpp/src/arrow/python/python_to_arrow.h
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/python/python_to_arrow.h b/cpp/src/arrow/python/python_to_arrow.h
index 8ac0396..2c80e5d 100644
--- a/cpp/src/arrow/python/python_to_arrow.h
+++ b/cpp/src/arrow/python/python_to_arrow.h
@@ -44,18 +44,12 @@ struct ARROW_EXPORT SerializedPyObject {
   std::vector<std::shared_ptr<Tensor>> tensors;
 };
 
-/// \brief Register callback functions to perform conversions to or from other
-/// Python representations en route to/from deserialization
-///
-/// \param[in] serialize_callback a Python callable
-/// \param[in] deserialize_callback a Python callable
-///
-/// Analogous to Python custom picklers / unpicklers
-ARROW_EXPORT
-void set_serialization_callbacks(PyObject* serialize_callback,
-                                 PyObject* deserialize_callback);
-
 /// \brief Serialize Python sequence as a RecordBatch plus
+/// \param[in] context Serialization context which contains custom serialization
+/// and deserialization callbacks. Can be any Python object with a
+/// _serialize_callback method for serialization and a _deserialize_callback
+/// method for deserialization. If context is None, no custom serialization
+/// will be attempted.
 /// \param[in] sequence a Python sequence object to serialize to Arrow data
 /// structures
 /// \param[out] out the serialized representation
@@ -63,7 +57,7 @@ void set_serialization_callbacks(PyObject* serialize_callback,
 ///
 /// Release GIL before calling
 ARROW_EXPORT
-Status SerializeObject(PyObject* sequence, SerializedPyObject* out);
+Status SerializeObject(PyObject* context, PyObject* sequence, SerializedPyObject* out);
 
 /// \brief Write serialized Python object to OutputStream
 /// \param[in] object a serialized Python object to write out

http://git-wip-us.apache.org/repos/asf/arrow/blob/441f9659/python/doc/source/api.rst
----------------------------------------------------------------------
diff --git a/python/doc/source/api.rst b/python/doc/source/api.rst
index 846af4c..4761c7f 100644
--- a/python/doc/source/api.rst
+++ b/python/doc/source/api.rst
@@ -207,6 +207,7 @@ Interprocess Communication and Serialization
    deserialize_from
    read_serialized
    SerializedPyObject
+   SerializationContext
 
 .. _api.memory_pool:
 

http://git-wip-us.apache.org/repos/asf/arrow/blob/441f9659/python/pyarrow/__init__.py
----------------------------------------------------------------------
diff --git a/python/pyarrow/__init__.py b/python/pyarrow/__init__.py
index 985e262..ecdcfcd 100644
--- a/python/pyarrow/__init__.py
+++ b/python/pyarrow/__init__.py
@@ -91,7 +91,7 @@ from pyarrow.lib import (ArrowException,
 # Serialization
 from pyarrow.lib import (deserialize_from, deserialize,
                          serialize, serialize_to, read_serialized,
-                         SerializedPyObject,
+                         SerializedPyObject, SerializationContext,
                          SerializationException, DeserializationException)
 
 from pyarrow.filesystem import FileSystem, LocalFileSystem

http://git-wip-us.apache.org/repos/asf/arrow/blob/441f9659/python/pyarrow/includes/libarrow.pxd
----------------------------------------------------------------------
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index fcf27da..98eda8b 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -797,20 +797,19 @@ cdef extern from "arrow/python/api.h" namespace 'arrow::py' nogil:
         shared_ptr[CRecordBatch] batch
         vector[shared_ptr[CTensor]] tensors
 
-    CStatus SerializeObject(object sequence, CSerializedPyObject* out)
+    CStatus SerializeObject(object context, object sequence,
+                            CSerializedPyObject* out)
 
     CStatus WriteSerializedObject(const CSerializedPyObject& obj,
                                   OutputStream* dst)
 
-    CStatus DeserializeObject(const CSerializedPyObject& obj,
+    CStatus DeserializeObject(object context,
+                              const CSerializedPyObject& obj,
                               PyObject* base, PyObject** out)
 
     CStatus ReadSerializedObject(RandomAccessFile* src,
                                  CSerializedPyObject* out)
 
-    void set_serialization_callbacks(object serialize_callback,
-                                     object deserialize_callback)
-
 
 cdef extern from 'arrow/python/init.h':
     int arrow_init_numpy() except -1

http://git-wip-us.apache.org/repos/asf/arrow/blob/441f9659/python/pyarrow/serialization.pxi
----------------------------------------------------------------------
diff --git a/python/pyarrow/serialization.pxi b/python/pyarrow/serialization.pxi
index 062121f..dc9bdaf 100644
--- a/python/pyarrow/serialization.pxi
+++ b/python/pyarrow/serialization.pxi
@@ -43,93 +43,98 @@ class DeserializationException(Exception):
         self.type_id = type_id
 
 
-# Types with special serialization handlers
-type_to_type_id = dict()
-whitelisted_types = dict()
-types_to_pickle = set()
-custom_serializers = dict()
-custom_deserializers = dict()
-
-
-def register_type(type, type_id, pickle=False,
-                  custom_serializer=None, custom_deserializer=None):
-    """Add type to the list of types we can serialize.
-
-    Parameters
-    ----------
-    type :type
-        The type that we can serialize.
-    type_id : bytes
-        A string of bytes used to identify the type.
-    pickle : bool
-        True if the serialization should be done with pickle.
-        False if it should be done efficiently with Arrow.
-    custom_serializer : callable
-        This argument is optional, but can be provided to
-        serialize objects of the class in a particular way.
-    custom_deserializer : callable
-        This argument is optional, but can be provided to
-        deserialize objects of the class in a particular way.
-    """
-    type_to_type_id[type] = type_id
-    whitelisted_types[type_id] = type
-    if pickle:
-        types_to_pickle.add(type_id)
-    if custom_serializer is not None:
-        custom_serializers[type_id] = custom_serializer
-        custom_deserializers[type_id] = custom_deserializer
-
-
-def _serialization_callback(obj):
-    if type(obj) not in type_to_type_id:
-        raise SerializationException("pyarrow does not know how to "
-                                     "serialize objects of type {}."
-                                     .format(type(obj)),
-                                     obj)
-    type_id = type_to_type_id[type(obj)]
-    if type_id in types_to_pickle:
-        serialized_obj = {"data": pickle.dumps(obj), "pickle": True}
-    elif type_id in custom_serializers:
-        serialized_obj = {"data": custom_serializers[type_id](obj)}
-    else:
-        if is_named_tuple(type(obj)):
-            serialized_obj = {}
-            serialized_obj["_pa_getnewargs_"] = obj.__getnewargs__()
-        elif hasattr(obj, "__dict__"):
-            serialized_obj = obj.__dict__
-        else:
-            raise SerializationException("We do not know how to serialize "
-                                         "the object '{}'".format(obj), obj)
-    return dict(serialized_obj, **{"_pytype_": type_id})
-
-
-def _deserialization_callback(serialized_obj):
-    type_id = serialized_obj["_pytype_"]
-
-    if "pickle" in serialized_obj:
-        # The object was pickled, so unpickle it.
-        obj = pickle.loads(serialized_obj["data"])
-    else:
-        assert type_id not in types_to_pickle
-        if type_id not in whitelisted_types:
-            raise "error"
-        type = whitelisted_types[type_id]
-        if type_id in custom_deserializers:
-            obj = custom_deserializers[type_id](serialized_obj["data"])
+cdef class SerializationContext:
+    cdef:
+        object type_to_type_id
+        object whitelisted_types
+        object types_to_pickle
+        object custom_serializers
+        object custom_deserializers
+
+    def __init__(self):
+        # Types with special serialization handlers
+        self.type_to_type_id = dict()
+        self.whitelisted_types = dict()
+        self.types_to_pickle = set()
+        self.custom_serializers = dict()
+        self.custom_deserializers = dict()
+
+    def register_type(self, type_, type_id, pickle=False,
+                      custom_serializer=None, custom_deserializer=None):
+        """EXPERIMENTAL: Add type to the list of types we can serialize.
+
+        Parameters
+        ----------
+        type_ : TypeType
+            The type that we can serialize.
+        type_id : bytes
+            A string of bytes used to identify the type.
+        pickle : bool
+            True if the serialization should be done with pickle.
+            False if it should be done efficiently with Arrow.
+        custom_serializer : callable
+            This argument is optional, but can be provided to
+            serialize objects of the class in a particular way.
+        custom_deserializer : callable
+            This argument is optional, but can be provided to
+            deserialize objects of the class in a particular way.
+        """
+        self.type_to_type_id[type_] = type_id
+        self.whitelisted_types[type_id] = type_
+        if pickle:
+            self.types_to_pickle.add(type_id)
+        if custom_serializer is not None:
+            self.custom_serializers[type_id] = custom_serializer
+            self.custom_deserializers[type_id] = custom_deserializer
+
+    def _serialize_callback(self, obj):
+        if type(obj) not in self.type_to_type_id:
+            raise SerializationException("pyarrow does not know how to "
+                                         "serialize objects of type {}."
+                                         .format(type(obj)),
+                                         obj)
+        type_id = self.type_to_type_id[type(obj)]
+        if type_id in self.types_to_pickle:
+            serialized_obj = {"data": pickle.dumps(obj), "pickle": True}
+        elif type_id in self.custom_serializers:
+            serialized_obj = {"data": self.custom_serializers[type_id](obj)}
         else:
-            # In this case, serialized_obj should just be the __dict__ field.
-            if "_pa_getnewargs_" in serialized_obj:
-                obj = type.__new__(type, *serialized_obj["_pa_getnewargs_"])
+            if is_named_tuple(type(obj)):
+                serialized_obj = {}
+                serialized_obj["_pa_getnewargs_"] = obj.__getnewargs__()
+            elif hasattr(obj, "__dict__"):
+                serialized_obj = obj.__dict__
             else:
-                obj = type.__new__(type)
-                serialized_obj.pop("_pytype_")
-                obj.__dict__.update(serialized_obj)
-    return obj
+                msg = "We do not know how to serialize " \
+                      "the object '{}'".format(obj)
+                raise SerializationException(msg, obj)
+        return dict(serialized_obj, **{"_pytype_": type_id})
 
+    def _deserialize_callback(self, serialized_obj):
+        type_id = serialized_obj["_pytype_"]
 
-set_serialization_callbacks(_serialization_callback,
-                            _deserialization_callback)
-
+        if "pickle" in serialized_obj:
+            # The object was pickled, so unpickle it.
+            obj = pickle.loads(serialized_obj["data"])
+        else:
+            assert type_id not in self.types_to_pickle
+            if type_id not in self.whitelisted_types:
+                raise "error"
+            type_ = self.whitelisted_types[type_id]
+            if type_id in self.custom_deserializers:
+                obj = self.custom_deserializers[type_id](
+                    serialized_obj["data"])
+            else:
+                # In this case, serialized_obj should just be
+                # the __dict__ field.
+                if "_pa_getnewargs_" in serialized_obj:
+                    obj = type_.__new__(
+                        type_, *serialized_obj["_pa_getnewargs_"])
+                else:
+                    obj = type_.__new__(type_)
+                    serialized_obj.pop("_pytype_")
+                    obj.__dict__.update(serialized_obj)
+        return obj
 
 cdef class SerializedPyObject:
     """
@@ -162,15 +167,15 @@ cdef class SerializedPyObject:
         with nogil:
             check_status(WriteSerializedObject(self.data, stream))
 
-    def deserialize(self):
+    def deserialize(self, SerializationContext context=None):
         """
         Convert back to Python object
         """
         cdef PyObject* result
 
         with nogil:
-            check_status(DeserializeObject(self.data, <PyObject*> self.base,
-                                           &result))
+            check_status(DeserializeObject(context, self.data,
+                                           <PyObject*> self.base, &result))
 
         # PyObject_to_object is necessary to avoid a memory leak;
         # also unpack the list the object was wrapped in in serialize
@@ -185,13 +190,15 @@ cdef class SerializedPyObject:
         return sink.get_result()
 
 
-def serialize(object value):
+def serialize(object value, SerializationContext context=None):
     """EXPERIMENTAL: Serialize a Python sequence
 
     Parameters
     ----------
     value: object
         Python object for the sequence that is to be serialized.
+    context : SerializationContext
+        Custom serialization and deserialization context
 
     Returns
     -------
@@ -200,11 +207,11 @@ def serialize(object value):
     cdef SerializedPyObject serialized = SerializedPyObject()
     wrapped_value = [value]
     with nogil:
-        check_status(SerializeObject(wrapped_value, &serialized.data))
+        check_status(SerializeObject(context, wrapped_value, &serialized.data))
     return serialized
 
 
-def serialize_to(object value, sink):
+def serialize_to(object value, sink, SerializationContext context=None):
     """EXPERIMENTAL: Serialize a Python sequence to a file.
 
     Parameters
@@ -213,8 +220,10 @@ def serialize_to(object value, sink):
         Python object for the sequence that is to be serialized.
     sink: NativeFile or file-like
         File the sequence will be written to.
+    context : SerializationContext
+        Custom serialization and deserialization context
     """
-    serialized = serialize(value)
+    serialized = serialize(value, context)
     serialized.write_to(sink)
 
 
@@ -244,7 +253,7 @@ def read_serialized(source, base=None):
     return serialized
 
 
-def deserialize_from(source, object base):
+def deserialize_from(source, object base, SerializationContext context=None):
     """EXPERIMENTAL: Deserialize a Python sequence from a file.
 
     Parameters
@@ -254,6 +263,8 @@ def deserialize_from(source, object base):
     base: object
         This object will be the base object of all the numpy arrays
         contained in the sequence.
+    context : SerializationContext
+        Custom serialization and deserialization context
 
     Returns
     -------
@@ -261,10 +272,10 @@ def deserialize_from(source, object base):
         Python object for the deserialized sequence.
     """
     serialized = read_serialized(source, base=base)
-    return serialized.deserialize()
+    return serialized.deserialize(context)
 
 
-def deserialize(obj):
+def deserialize(obj, SerializationContext context=None):
     """
     EXPERIMENTAL: Deserialize Python object from Buffer or other Python object
     supporting the buffer protocol
@@ -272,10 +283,12 @@ def deserialize(obj):
     Parameters
     ----------
     obj : pyarrow.Buffer or Python object supporting buffer protocol
+    context : SerializationContext
+        Custom serialization and deserialization context
 
     Returns
     -------
     deserialized : object
     """
     source = BufferReader(obj)
-    return deserialize_from(source, obj)
+    return deserialize_from(source, obj, context)

http://git-wip-us.apache.org/repos/asf/arrow/blob/441f9659/python/pyarrow/tests/test_serialization.py
----------------------------------------------------------------------
diff --git a/python/pyarrow/tests/test_serialization.py b/python/pyarrow/tests/test_serialization.py
index b2aa4af..d922576 100644
--- a/python/pyarrow/tests/test_serialization.py
+++ b/python/pyarrow/tests/test_serialization.py
@@ -81,18 +81,6 @@ def assert_equal(obj1, obj2):
                                                                        obj2)
 
 
-def array_custom_serializer(obj):
-    return obj.tolist(), obj.dtype.str
-
-
-def array_custom_deserializer(serialized_obj):
-    return np.array(serialized_obj[0], dtype=np.dtype(serialized_obj[1]))
-
-
-pa.lib.register_type(np.ndarray, 20 * b"\x00", pickle=False,
-                     custom_serializer=array_custom_serializer,
-                     custom_deserializer=array_custom_deserializer)
-
 if sys.version_info >= (3, 0):
     long_extras = [0, np.array([["hi", u"hi"], [1.3, 1]])]
 else:
@@ -155,6 +143,11 @@ class SubQux(Qux):
         Qux.__init__(self)
 
 
+class SubQuxPickle(Qux):
+    def __init__(self):
+        Qux.__init__(self)
+
+
 class CustomError(Exception):
     pass
 
@@ -165,41 +158,61 @@ NamedTupleExample = namedtuple("Example",
 
 
 CUSTOM_OBJECTS = [Exception("Test object."), CustomError(), Point(11, y=22),
-                  Foo(), Bar(), Baz(), Qux(), SubQux(),
+                  Foo(), Bar(), Baz(), Qux(), SubQux(), SubQuxPickle(),
                   NamedTupleExample(1, 1.0, "hi", np.zeros([3, 5]), [1, 2, 3])]
 
-pa.lib.register_type(Foo, 20 * b"\x01")
-pa.lib.register_type(Bar, 20 * b"\x02")
-pa.lib.register_type(Baz, 20 * b"\x03")
-pa.lib.register_type(Qux, 20 * b"\x04")
-pa.lib.register_type(SubQux, 20 * b"\x05")
-pa.lib.register_type(Exception, 20 * b"\x06")
-pa.lib.register_type(CustomError, 20 * b"\x07")
-pa.lib.register_type(Point, 20 * b"\x08")
-pa.lib.register_type(NamedTupleExample, 20 * b"\x09")
-
-# TODO(pcm): This is currently a workaround until arrow supports
-# arbitrary precision integers. This is only called on long integers,
-# see the associated case in the append method in python_to_arrow.cc
-pa.lib.register_type(int, 20 * b"\x10", pickle=False,
-                     custom_serializer=lambda obj: str(obj),
-                     custom_deserializer=(
-                         lambda serialized_obj: int(serialized_obj)))
-
-
-if (sys.version_info < (3, 0)):
-    deserializer = (
-        lambda serialized_obj: long(serialized_obj))  # noqa: E501,F821
-    pa.lib.register_type(long, 20 * b"\x11", pickle=False,  # noqa: E501,F821
-                         custom_serializer=lambda obj: str(obj),
-                         custom_deserializer=deserializer)
+
+def make_serialization_context():
+
+    def array_custom_serializer(obj):
+        return obj.tolist(), obj.dtype.str
+
+    def array_custom_deserializer(serialized_obj):
+        return np.array(serialized_obj[0], dtype=np.dtype(serialized_obj[1]))
+
+    context = pa.SerializationContext()
+
+    context.register_type(np.ndarray, 20 * b"\x00",
+                          custom_serializer=array_custom_serializer,
+                          custom_deserializer=array_custom_deserializer)
+
+    context.register_type(Foo, 20 * b"\x01")
+    context.register_type(Bar, 20 * b"\x02")
+    context.register_type(Baz, 20 * b"\x03")
+    context.register_type(Qux, 20 * b"\x04")
+    context.register_type(SubQux, 20 * b"\x05")
+    context.register_type(SubQuxPickle, 20 * b"\x05", pickle=True)
+    context.register_type(Exception, 20 * b"\x06")
+    context.register_type(CustomError, 20 * b"\x07")
+    context.register_type(Point, 20 * b"\x08")
+    context.register_type(NamedTupleExample, 20 * b"\x09")
+
+    # TODO(pcm): This is currently a workaround until arrow supports
+    # arbitrary precision integers. This is only called on long integers,
+    # see the associated case in the append method in python_to_arrow.cc
+    context.register_type(int, 20 * b"\x10", pickle=False,
+                          custom_serializer=lambda obj: str(obj),
+                          custom_deserializer=(
+                              lambda serialized_obj: int(serialized_obj)))
+
+    if (sys.version_info < (3, 0)):
+        deserializer = (
+            lambda serialized_obj: long(serialized_obj))  # noqa: E501,F821
+        context.register_type(long, 20 * b"\x11", pickle=False,  # noqa: E501,F821
+                              custom_serializer=lambda obj: str(obj),
+                              custom_deserializer=deserializer)
+
+    return context
+
+
+serialization_context = make_serialization_context()
 
 
 def serialization_roundtrip(value, f):
     f.seek(0)
-    pa.serialize_to(value, f)
+    pa.serialize_to(value, f, serialization_context)
     f.seek(0)
-    result = pa.deserialize_from(f, None)
+    result = pa.deserialize_from(f, None, serialization_context)
     assert_equal(value, result)
 
 

