spark-commits mailing list archives

From dav...@apache.org
Subject spark git commit: [SPARK-13697] [PYSPARK] Fix the missing module name of TransformFunctionSerializer.loads
Date Sun, 06 Mar 2016 16:57:27 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.5 f521c4470 -> 9ecd33e96


[SPARK-13697] [PYSPARK] Fix the missing module name of TransformFunctionSerializer.loads

## What changes were proposed in this pull request?

Set the function's module name to `__main__` if it's missing in `TransformFunctionSerializer.loads`.

## How was this patch tested?

Manually tested in the PySpark shell.

Before this patch:
```
>>> from pyspark.streaming import StreamingContext
>>> from pyspark.streaming.util import TransformFunction
>>> ssc = StreamingContext(sc, 1)
>>> func = TransformFunction(sc, lambda x: x, sc.serializer)
>>> func.rdd_wrapper(lambda x: x)
TransformFunction(<function <lambda> at 0x106ac8b18>)
>>> bytes = bytearray(ssc._transformerSerializer.serializer.dumps((func.func, func.rdd_wrap_func, func.deserializers)))
>>> func2 = ssc._transformerSerializer.loads(bytes)
>>> print(func2.func.__module__)
None
>>> print(func2.rdd_wrap_func.__module__)
None
>>>
```
After this patch:
```
>>> from pyspark.streaming import StreamingContext
>>> from pyspark.streaming.util import TransformFunction
>>> ssc = StreamingContext(sc, 1)
>>> func = TransformFunction(sc, lambda x: x, sc.serializer)
>>> func.rdd_wrapper(lambda x: x)
TransformFunction(<function <lambda> at 0x108bf1b90>)
>>> bytes = bytearray(ssc._transformerSerializer.serializer.dumps((func.func, func.rdd_wrap_func, func.deserializers)))
>>> func2 = ssc._transformerSerializer.loads(bytes)
>>> print(func2.func.__module__)
__main__
>>> print(func2.rdd_wrap_func.__module__)
__main__
>>>
```
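
The symptom can also be reproduced without Spark. The following is a minimal sketch that assumes only the Python standard library; `identity`, `skeleton`, and `fill_function` are illustrative names that do not exist in PySpark, and `fill_function` only mirrors the shape of `_fill_function` after this patch. It shows that a function rebuilt from a code object with a globals dict lacking `__name__` has no module name until one is assigned explicitly.

```python
import types


def identity(x):
    return x


# Rebuild the function from its code object with an empty globals dict.
# With no "__name__" entry in the globals, CPython leaves __module__ unset.
skeleton = types.FunctionType(identity.__code__, {})
print(skeleton.__module__)  # None


def fill_function(func, globals_, defaults, dict_, module):
    """Hypothetical stand-in mirroring the patched _fill_function."""
    func.__globals__.update(globals_)
    func.__defaults__ = defaults
    func.__dict__ = dict_
    func.__module__ = module  # the assignment this patch adds
    return func


# Carry the original module name across, as the patched pickler does
# by saving func.__module__ alongside the other function state.
restored = fill_function(
    skeleton, {}, identity.__defaults__, {}, identity.__module__
)
print(restored.__module__ == identity.__module__)  # True
```

Passing the module name as an explicit argument keeps the restore step independent of whatever happens to be in the reconstructed globals.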

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #11535 from zsxwing/loads-module.

(cherry picked from commit ee913e6e2d58dfac20f3f06ff306081bd0e48066)
Signed-off-by: Davies Liu <davies.liu@gmail.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9ecd33e9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9ecd33e9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9ecd33e9

Branch: refs/heads/branch-1.5
Commit: 9ecd33e96139a699d0ef6e76b3bab62021c62756
Parents: f521c44
Author: Shixiong Zhu <shixiong@databricks.com>
Authored: Sun Mar 6 08:57:01 2016 -0800
Committer: Davies Liu <davies.liu@gmail.com>
Committed: Sun Mar 6 08:57:23 2016 -0800

----------------------------------------------------------------------
 python/pyspark/cloudpickle.py | 4 +++-
 python/pyspark/tests.py       | 6 ++++++
 2 files changed, 9 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9ecd33e9/python/pyspark/cloudpickle.py
----------------------------------------------------------------------
diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py
index 95b3abc..e56e22a 100644
--- a/python/pyspark/cloudpickle.py
+++ b/python/pyspark/cloudpickle.py
@@ -241,6 +241,7 @@ class CloudPickler(Pickler):
         save(f_globals)
         save(defaults)
         save(dct)
+        save(func.__module__)
         write(pickle.TUPLE)
         write(pickle.REDUCE)  # applies _fill_function on the tuple
 
@@ -698,13 +699,14 @@ def _genpartial(func, args, kwds):
     return partial(func, *args, **kwds)
 
 
-def _fill_function(func, globals, defaults, dict):
+def _fill_function(func, globals, defaults, dict, module):
     """ Fills in the rest of function data into the skeleton function object
         that were created via _make_skel_func().
          """
     func.__globals__.update(globals)
     func.__defaults__ = defaults
     func.__dict__ = dict
+    func.__module__ = module
 
     return func
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9ecd33e9/python/pyspark/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 647504c..d323136 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -235,6 +235,12 @@ class SerializationTestCase(unittest.TestCase):
         getter2 = ser.loads(ser.dumps(getter))
         self.assertEqual(getter(d), getter2(d))
 
+    def test_function_module_name(self):
+        ser = CloudPickleSerializer()
+        func = lambda x: x
+        func2 = ser.loads(ser.dumps(func))
+        self.assertEqual(func.__module__, func2.__module__)
+
     def test_attrgetter(self):
         from operator import attrgetter
         ser = CloudPickleSerializer()



