spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From ma...@apache.org
Subject [3/7] git commit: Add custom serializer support to PySpark.
Date Wed, 27 Nov 2013 04:56:00 GMT
Add custom serializer support to PySpark.

For now, this only adds MarshalSerializer, but it lays the groundwork
for supporting other custom serializers.  Many of these mechanisms
can also be used to support deserialization of different data formats
sent by Java, such as data encoded by MsgPack.

This also fixes a bug in SparkContext.union().


Project: http://git-wip-us.apache.org/repos/asf/incubator-spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-spark/commit/cbb7f04a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-spark/tree/cbb7f04a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-spark/diff/cbb7f04a

Branch: refs/heads/master
Commit: cbb7f04aef2220ece93dea9f3fa98b5db5f270d6
Parents: 7d68a81
Author: Josh Rosen <joshrosen@apache.org>
Authored: Tue Nov 5 17:52:39 2013 -0800
Committer: Josh Rosen <joshrosen@apache.org>
Committed: Sun Nov 10 16:45:38 2013 -0800

----------------------------------------------------------------------
 .../org/apache/spark/api/python/PythonRDD.scala |  23 +-
 python/epydoc.conf                              |   2 +-
 python/pyspark/accumulators.py                  |   6 +-
 python/pyspark/context.py                       |  61 +++-
 python/pyspark/rdd.py                           |  86 ++---
 python/pyspark/serializers.py                   | 310 +++++++++++++++----
 python/pyspark/tests.py                         |   3 +-
 python/pyspark/worker.py                        |  41 ++-
 python/run-tests                                |   1 +
 9 files changed, 363 insertions(+), 170 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/cbb7f04a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index eb0b0db..ef9bf4d 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -221,18 +221,6 @@ private[spark] object PythonRDD {
     JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
   }
 
-  def writeStringAsPickle(elem: String, dOut: DataOutputStream) {
-    val s = elem.getBytes("UTF-8")
-    val length = 2 + 1 + 4 + s.length + 1
-    dOut.writeInt(length)
-    dOut.writeByte(Pickle.PROTO)
-    dOut.writeByte(Pickle.TWO)
-    dOut.write(Pickle.BINUNICODE)
-    dOut.writeInt(Integer.reverseBytes(s.length))
-    dOut.write(s)
-    dOut.writeByte(Pickle.STOP)
-  }
-
   def writeToStream(elem: Any, dataOut: DataOutputStream) {
     elem match {
       case bytes: Array[Byte] =>
@@ -244,9 +232,7 @@ private[spark] object PythonRDD {
         dataOut.writeInt(pair._2.length)
         dataOut.write(pair._2)
       case str: String =>
-        // Until we've implemented full custom serializer support, we need to return
-        // strings as Pickles to properly support union() and cartesian():
-        writeStringAsPickle(str, dataOut)
+        dataOut.writeUTF(str)
       case other =>
         throw new SparkException("Unexpected element type " + other.getClass)
     }
@@ -271,13 +257,6 @@ private[spark] object PythonRDD {
   }
 }
 
-private object Pickle {
-  val PROTO: Byte = 0x80.toByte
-  val TWO: Byte = 0x02.toByte
-  val BINUNICODE: Byte = 'X'
-  val STOP: Byte = '.'
-}
-
 private class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] {
   override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8")
 }

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/cbb7f04a/python/epydoc.conf
----------------------------------------------------------------------
diff --git a/python/epydoc.conf b/python/epydoc.conf
index 1d0d002..0b42e72 100644
--- a/python/epydoc.conf
+++ b/python/epydoc.conf
@@ -32,6 +32,6 @@ target: docs/
 
 private: no
 
-exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers
+exclude: pyspark.cloudpickle pyspark.worker pyspark.join
          pyspark.java_gateway pyspark.examples pyspark.shell pyspark.test
          pyspark.rddsampler pyspark.daemon

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/cbb7f04a/python/pyspark/accumulators.py
----------------------------------------------------------------------
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index da3d966..2204e9c 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -90,9 +90,11 @@ import struct
 import SocketServer
 import threading
 from pyspark.cloudpickle import CloudPickler
-from pyspark.serializers import read_int, read_with_length, load_pickle
+from pyspark.serializers import read_int, PickleSerializer
 
 
+pickleSer = PickleSerializer()
+
 # Holds accumulators registered on the current machine, keyed by ID. This is then used to send
 # the local accumulator updates back to the driver program at the end of a task.
 _accumulatorRegistry = {}
@@ -211,7 +213,7 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
         from pyspark.accumulators import _accumulatorRegistry
         num_updates = read_int(self.rfile)
         for _ in range(num_updates):
-            (aid, update) = load_pickle(read_with_length(self.rfile))
+            (aid, update) = pickleSer._read_with_length(self.rfile)
             _accumulatorRegistry[aid] += update
         # Write a byte in acknowledgement
         self.wfile.write(struct.pack("!b", 1))

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/cbb7f04a/python/pyspark/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 0fec1a6..6bb1c6c 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -26,7 +26,7 @@ from pyspark.accumulators import Accumulator
 from pyspark.broadcast import Broadcast
 from pyspark.files import SparkFiles
 from pyspark.java_gateway import launch_gateway
-from pyspark.serializers import dump_pickle, write_with_length, batched
+from pyspark.serializers import PickleSerializer, BatchedSerializer, MUTF8Deserializer
 from pyspark.storagelevel import StorageLevel
 from pyspark.rdd import RDD
 
@@ -51,7 +51,7 @@ class SparkContext(object):
 
 
     def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
-        environment=None, batchSize=1024):
+        environment=None, batchSize=1024, serializer=PickleSerializer()):
         """
         Create a new SparkContext.
 
@@ -67,6 +67,7 @@ class SparkContext(object):
         @param batchSize: The number of Python objects represented as a single
                Java object.  Set 1 to disable batching or -1 to use an
                unlimited batch size.
+        @param serializer: The serializer for RDDs.
 
 
         >>> from pyspark.context import SparkContext
@@ -83,7 +84,13 @@ class SparkContext(object):
         self.jobName = jobName
         self.sparkHome = sparkHome or None # None becomes null in Py4J
         self.environment = environment or {}
-        self.batchSize = batchSize  # -1 represents a unlimited batch size
+        self._batchSize = batchSize  # -1 represents an unlimited batch size
+        self._unbatched_serializer = serializer
+        if batchSize == 1:
+            self.serializer = self._unbatched_serializer
+        else:
+            self.serializer = BatchedSerializer(self._unbatched_serializer,
+                                                batchSize)
 
         # Create the Java SparkContext through Py4J
         empty_string_array = self._gateway.new_array(self._jvm.String, 0)
@@ -184,15 +191,17 @@ class SparkContext(object):
         # Make sure we distribute data evenly if it's smaller than self.batchSize
         if "__len__" not in dir(c):
             c = list(c)    # Make it a list so we can compute its length
-        batchSize = min(len(c) // numSlices, self.batchSize)
+        batchSize = min(len(c) // numSlices, self._batchSize)
         if batchSize > 1:
-            c = batched(c, batchSize)
-        for x in c:
-            write_with_length(dump_pickle(x), tempFile)
+            serializer = BatchedSerializer(self._unbatched_serializer,
+                                           batchSize)
+        else:
+            serializer = self._unbatched_serializer
+        serializer.dump_stream(c, tempFile)
         tempFile.close()
         readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
         jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices)
-        return RDD(jrdd, self)
+        return RDD(jrdd, self, serializer)
 
     def textFile(self, name, minSplits=None):
         """
@@ -201,21 +210,39 @@ class SparkContext(object):
         RDD of Strings.
         """
         minSplits = minSplits or min(self.defaultParallelism, 2)
-        jrdd = self._jsc.textFile(name, minSplits)
-        return RDD(jrdd, self)
+        return RDD(self._jsc.textFile(name, minSplits), self,
+                   MUTF8Deserializer())
 
-    def _checkpointFile(self, name):
+    def _checkpointFile(self, name, input_deserializer):
         jrdd = self._jsc.checkpointFile(name)
-        return RDD(jrdd, self)
+        return RDD(jrdd, self, input_deserializer)
 
     def union(self, rdds):
         """
         Build the union of a list of RDDs.
+
+        This supports unions() of RDDs with different serialized formats,
+        although this forces them to be reserialized using the default
+        serializer:
+
+        >>> path = os.path.join(tempdir, "union-text.txt")
+        >>> with open(path, "w") as testFile:
+        ...    testFile.write("Hello")
+        >>> textFile = sc.textFile(path)
+        >>> textFile.collect()
+        [u'Hello']
+        >>> parallelized = sc.parallelize(["World!"])
+        >>> sorted(sc.union([textFile, parallelized]).collect())
+        [u'Hello', 'World!']
         """
+        first_jrdd_deserializer = rdds[0]._jrdd_deserializer
+        if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds):
+            rdds = [x._reserialize() for x in rdds]
         first = rdds[0]._jrdd
         rest = [x._jrdd for x in rdds[1:]]
-        rest = ListConverter().convert(rest, self.gateway._gateway_client)
-        return RDD(self._jsc.union(first, rest), self)
+        rest = ListConverter().convert(rest, self._gateway._gateway_client)
+        return RDD(self._jsc.union(first, rest), self,
+                   rdds[0]._jrdd_deserializer)
 
     def broadcast(self, value):
         """
@@ -223,7 +250,9 @@ class SparkContext(object):
         object for reading it in distributed functions. The variable will be
         sent to each cluster only once.
         """
-        jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value)))
+        pickleSer = PickleSerializer()
+        pickled = pickleSer._dumps(value)
+        jbroadcast = self._jsc.broadcast(bytearray(pickled))
         return Broadcast(jbroadcast.id(), value, jbroadcast,
                          self._pickled_broadcast_vars)
 
@@ -235,7 +264,7 @@ class SparkContext(object):
         and floating-point numbers if you do not provide one. For other types,
         a custom AccumulatorParam can be used.
         """
-        if accum_param == None:
+        if accum_param is None:
             if isinstance(value, int):
                 accum_param = accumulators.INT_ACCUMULATOR_PARAM
             elif isinstance(value, float):

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/cbb7f04a/python/pyspark/rdd.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index d3c4d13..6691c30 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -18,7 +18,7 @@
 from base64 import standard_b64encode as b64enc
 import copy
 from collections import defaultdict
-from itertools import chain, ifilter, imap, product
+from itertools import chain, ifilter, imap
 import operator
 import os
 import sys
@@ -28,8 +28,8 @@ from tempfile import NamedTemporaryFile
 from threading import Thread
 
 from pyspark import cloudpickle
-from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \
-    read_from_pickle_file, pack_long
+from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
+    BatchedSerializer, pack_long
 from pyspark.join import python_join, python_left_outer_join, \
     python_right_outer_join, python_cogroup
 from pyspark.statcounter import StatCounter
@@ -48,13 +48,12 @@ class RDD(object):
     operated on in parallel.
     """
 
-    def __init__(self, jrdd, ctx):
+    def __init__(self, jrdd, ctx, jrdd_deserializer):
         self._jrdd = jrdd
         self.is_cached = False
         self.is_checkpointed = False
         self.ctx = ctx
-        self._partitionFunc = None
-        self._stage_input_is_pairs = False
+        self._jrdd_deserializer = jrdd_deserializer
 
     @property
     def context(self):
@@ -248,7 +247,23 @@ class RDD(object):
         >>> rdd.union(rdd).collect()
         [1, 1, 2, 3, 1, 1, 2, 3]
         """
-        return RDD(self._jrdd.union(other._jrdd), self.ctx)
+        if self._jrdd_deserializer == other._jrdd_deserializer:
+            rdd = RDD(self._jrdd.union(other._jrdd), self.ctx,
+                      self._jrdd_deserializer)
+            return rdd
+        else:
+            # These RDDs contain data in different serialized formats, so we
+            # must normalize them to the default serializer.
+            self_copy = self._reserialize()
+            other_copy = other._reserialize()
+            return RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx,
+                       self.ctx.serializer)
+
+    def _reserialize(self):
+        if self._jrdd_deserializer == self.ctx.serializer:
+            return self
+        else:
+            return self.map(lambda x: x, preservesPartitioning=True)
 
     def __add__(self, other):
         """
@@ -335,18 +350,9 @@ class RDD(object):
         [(1, 1), (1, 2), (2, 1), (2, 2)]
         """
         # Due to batching, we can't use the Java cartesian method.
-        java_cartesian = RDD(self._jrdd.cartesian(other._jrdd), self.ctx)
-        def unpack_batches(pair):
-            (x, y) = pair
-            if type(x) == Batch or type(y) == Batch:
-                xs = x.items if type(x) == Batch else [x]
-                ys = y.items if type(y) == Batch else [y]
-                for pair in product(xs, ys):
-                    yield pair
-            else:
-                yield pair
-        java_cartesian._stage_input_is_pairs = True
-        return java_cartesian.flatMap(unpack_batches)
+        deserializer = CartesianDeserializer(self._jrdd_deserializer,
+                                             other._jrdd_deserializer)
+        return RDD(self._jrdd.cartesian(other._jrdd), self.ctx, deserializer)
 
     def groupBy(self, f, numPartitions=None):
         """
@@ -405,7 +411,7 @@ class RDD(object):
         self.ctx._writeToFile(iterator, tempFile.name)
         # Read the data into Python and deserialize it:
         with open(tempFile.name, 'rb') as tempFile:
-            for item in read_from_pickle_file(tempFile):
+            for item in self._jrdd_deserializer.load_stream(tempFile):
                 yield item
         os.unlink(tempFile.name)
 
@@ -573,7 +579,7 @@ class RDD(object):
         items = []
         for partition in range(mapped._jrdd.splits().size()):
             iterator = self.ctx._takePartition(mapped._jrdd.rdd(), partition)
-            items.extend(self._collect_iterator_through_file(iterator))
+            items.extend(mapped._collect_iterator_through_file(iterator))
             if len(items) >= num:
                 break
         return items[:num]
@@ -737,6 +743,7 @@ class RDD(object):
         # Transferring O(n) objects to Java is too expensive.  Instead, we'll
         # form the hash buckets in Python, transferring O(numPartitions) objects
         # to Java.  Each object is a (splitNumber, [objects]) pair.
+        outputSerializer = self.ctx._unbatched_serializer
         def add_shuffle_key(split, iterator):
 
             buckets = defaultdict(list)
@@ -745,14 +752,14 @@ class RDD(object):
                 buckets[partitionFunc(k) % numPartitions].append((k, v))
             for (split, items) in buckets.iteritems():
                 yield pack_long(split)
-                yield dump_pickle(Batch(items))
+                yield outputSerializer._dumps(items)
         keyed = PipelinedRDD(self, add_shuffle_key)
         keyed._bypass_serializer = True
         pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
         partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
                                                      id(partitionFunc))
         jrdd = pairRDD.partitionBy(partitioner).values()
-        rdd = RDD(jrdd, self.ctx)
+        rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
         # This is required so that id(partitionFunc) remains unique, even if
         # partitionFunc is a lambda:
         rdd._partitionFunc = partitionFunc
@@ -789,7 +796,8 @@ class RDD(object):
             numPartitions = self.ctx.defaultParallelism
         def combineLocally(iterator):
             combiners = {}
-            for (k, v) in iterator:
+            for x in iterator:
+                (k, v) = x
                 if k not in combiners:
                     combiners[k] = createCombiner(v)
                 else:
@@ -931,38 +939,38 @@ class PipelinedRDD(RDD):
     20
     """
     def __init__(self, prev, func, preservesPartitioning=False):
-        if isinstance(prev, PipelinedRDD) and prev._is_pipelinable():
+        if not isinstance(prev, PipelinedRDD) or not prev._is_pipelinable():
+            # This transformation is the first in its stage:
+            self.func = func
+            self.preservesPartitioning = preservesPartitioning
+            self._prev_jrdd = prev._jrdd
+            self._prev_jrdd_deserializer = prev._jrdd_deserializer
+        else:
             prev_func = prev.func
             def pipeline_func(split, iterator):
                 return func(split, prev_func(split, iterator))
             self.func = pipeline_func
             self.preservesPartitioning = \
                 prev.preservesPartitioning and preservesPartitioning
-            self._prev_jrdd = prev._prev_jrdd
-        else:
-            self.func = func
-            self.preservesPartitioning = preservesPartitioning
-            self._prev_jrdd = prev._jrdd
-        self._stage_input_is_pairs = prev._stage_input_is_pairs
+            self._prev_jrdd = prev._prev_jrdd  # maintain the pipeline
+            self._prev_jrdd_deserializer = prev._prev_jrdd_deserializer
         self.is_cached = False
         self.is_checkpointed = False
         self.ctx = prev.ctx
         self.prev = prev
         self._jrdd_val = None
+        self._jrdd_deserializer = self.ctx.serializer
         self._bypass_serializer = False
 
     @property
     def _jrdd(self):
         if self._jrdd_val:
             return self._jrdd_val
-        func = self.func
-        if not self._bypass_serializer and self.ctx.batchSize != 1:
-            oldfunc = self.func
-            batchSize = self.ctx.batchSize
-            def batched_func(split, iterator):
-                return batched(oldfunc(split, iterator), batchSize)
-            func = batched_func
-        cmds = [func, self._bypass_serializer, self._stage_input_is_pairs]
+        if self._bypass_serializer:
+            serializer = NoOpSerializer()
+        else:
+            serializer = self.ctx.serializer
+        cmds = [self.func, self._prev_jrdd_deserializer, serializer]
         pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
         broadcast_vars = ListConverter().convert(
             [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/cbb7f04a/python/pyspark/serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index fd02e1e..4fb4444 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -15,8 +15,58 @@
 # limitations under the License.
 #
 
-import struct
+"""
+PySpark supports custom serializers for transferring data; this can improve
+performance.
+
+By default, PySpark uses L{PickleSerializer} to serialize objects using Python's
+C{cPickle} serializer, which can serialize nearly any Python object.
+Other serializers, like L{MarshalSerializer}, support fewer datatypes but can be
+faster.
+
+The serializer is chosen when creating L{SparkContext}:
+
+>>> from pyspark.context import SparkContext
+>>> from pyspark.serializers import MarshalSerializer
+>>> sc = SparkContext('local', 'test', serializer=MarshalSerializer())
+>>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10)
+[0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
+>>> sc.stop()
+
+By default, PySpark serializes objects in batches; the batch size can be
+controlled through SparkContext's C{batchSize} parameter
+(the default size is 1024 objects):
+
+>>> sc = SparkContext('local', 'test', batchSize=2)
+>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x)
+
+Behind the scenes, this creates a JavaRDD with four partitions, each of
+which contains two batches of two objects:
+
+>>> rdd.glom().collect()
+[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
+>>> rdd._jrdd.count()
+8L
+>>> sc.stop()
+
+A batch size of -1 uses an unlimited batch size, and a size of 1 disables
+batching:
+
+>>> sc = SparkContext('local', 'test', batchSize=1)
+>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x)
+>>> rdd.glom().collect()
+[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
+>>> rdd._jrdd.count()
+16L
+"""
+
 import cPickle
+from itertools import chain, izip, product
+import marshal
+import struct
+
+
+__all__ = ["PickleSerializer", "MarshalSerializer"]
 
 
 class SpecialLengths(object):
@@ -25,41 +75,206 @@ class SpecialLengths(object):
     TIMING_DATA = -3
 
 
-class Batch(object):
+class Serializer(object):
+
+    def dump_stream(self, iterator, stream):
+        """
+        Serialize an iterator of objects to the output stream.
+        """
+        raise NotImplementedError
+
+    def load_stream(self, stream):
+        """
+        Return an iterator of deserialized objects from the input stream.
+        """
+        raise NotImplementedError
+
+
+    def _load_stream_without_unbatching(self, stream):
+        return self.load_stream(stream)
+
+    # Note: our notion of "equality" is that output generated by
+    # equal serializers can be deserialized using the same serializer.
+
+    # This default implementation handles the simple cases;
+    # subclasses should override __eq__ as appropriate.
+
+    def __eq__(self, other):
+        return isinstance(other, self.__class__)
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+
+class FramedSerializer(Serializer):
+    """
+    Serializer that writes objects as a stream of (length, data) pairs,
+    where C{length} is a 32-bit integer and data is C{length} bytes.
+    """
+
+    def dump_stream(self, iterator, stream):
+        for obj in iterator:
+            self._write_with_length(obj, stream)
+
+    def load_stream(self, stream):
+        while True:
+            try:
+                yield self._read_with_length(stream)
+            except EOFError:
+                return
+
+    def _write_with_length(self, obj, stream):
+        serialized = self._dumps(obj)
+        write_int(len(serialized), stream)
+        stream.write(serialized)
+
+    def _read_with_length(self, stream):
+        length = read_int(stream)
+        obj = stream.read(length)
+        if obj == "":
+            raise EOFError
+        return self._loads(obj)
+
+    def _dumps(self, obj):
+        """
+        Serialize an object into a byte array.
+        When batching is used, this will be called with an array of objects.
+        """
+        raise NotImplementedError
+
+    def _loads(self, obj):
+        """
+        Deserialize an object from a byte array.
+        """
+        raise NotImplementedError
+
+
+class BatchedSerializer(Serializer):
+    """
+    Serializes a stream of objects in batches by calling its wrapped
+    Serializer with streams of objects.
+    """
+
+    UNLIMITED_BATCH_SIZE = -1
+
+    def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE):
+        self.serializer = serializer
+        self.batchSize = batchSize
+
+    def _batched(self, iterator):
+        if self.batchSize == self.UNLIMITED_BATCH_SIZE:
+            yield list(iterator)
+        else:
+            items = []
+            count = 0
+            for item in iterator:
+                items.append(item)
+                count += 1
+                if count == self.batchSize:
+                    yield items
+                    items = []
+                    count = 0
+            if items:
+                yield items
+
+    def dump_stream(self, iterator, stream):
+        if isinstance(iterator, basestring):
+            iterator = [iterator]
+        self.serializer.dump_stream(self._batched(iterator), stream)
+
+    def load_stream(self, stream):
+        return chain.from_iterable(self._load_stream_without_unbatching(stream))
+
+    def _load_stream_without_unbatching(self, stream):
+            return self.serializer.load_stream(stream)
+
+    def __eq__(self, other):
+        return isinstance(other, BatchedSerializer) and \
+               other.serializer == self.serializer
+
+    def __str__(self):
+        return "BatchedSerializer<%s>" % str(self.serializer)
+
+
+class CartesianDeserializer(FramedSerializer):
     """
-    Used to store multiple RDD entries as a single Java object.
+    Deserializes the JavaRDD cartesian() of two PythonRDDs.
+    """
+
+    def __init__(self, key_ser, val_ser):
+        self.key_ser = key_ser
+        self.val_ser = val_ser
+
+    def load_stream(self, stream):
+        key_stream = self.key_ser._load_stream_without_unbatching(stream)
+        val_stream = self.val_ser._load_stream_without_unbatching(stream)
+        key_is_batched = isinstance(self.key_ser, BatchedSerializer)
+        val_is_batched = isinstance(self.val_ser, BatchedSerializer)
+        for (keys, vals) in izip(key_stream, val_stream):
+            keys = keys if key_is_batched else [keys]
+            vals = vals if val_is_batched else [vals]
+            for pair in product(keys, vals):
+                yield pair
+
+    def __eq__(self, other):
+        return isinstance(other, CartesianDeserializer) and \
+               self.key_ser == other.key_ser and self.val_ser == other.val_ser
+
+    def __str__(self):
+        return "CartesianDeserializer<%s, %s>" % \
+               (str(self.key_ser), str(self.val_ser))
+
+
+class NoOpSerializer(FramedSerializer):
+
+    def _loads(self, obj): return obj
+    def _dumps(self, obj): return obj
+
+
+class PickleSerializer(FramedSerializer):
+    """
+    Serializes objects using Python's cPickle serializer:
+
+        http://docs.python.org/2/library/pickle.html
+
+    This serializer supports nearly any Python object, but may
+    not be as fast as more specialized serializers.
+    """
+
+    def _dumps(self, obj): return cPickle.dumps(obj, 2)
+    _loads = cPickle.loads
+
 
-    This relieves us from having to explicitly track whether an RDD
-    is stored as batches of objects and avoids problems when processing
-    the union() of batched and unbatched RDDs (e.g. the union() of textFile()
-    with another RDD).
+class MarshalSerializer(FramedSerializer):
     """
-    def __init__(self, items):
-        self.items = items
+    Serializes objects using Python's Marshal serializer:
 
+        http://docs.python.org/2/library/marshal.html
 
-def batched(iterator, batchSize):
-    if batchSize == -1: # unlimited batch size
-        yield Batch(list(iterator))
-    else:
-        items = []
-        count = 0
-        for item in iterator:
-            items.append(item)
-            count += 1
-            if count == batchSize:
-                yield Batch(items)
-                items = []
-                count = 0
-        if items:
-            yield Batch(items)
+    This serializer is faster than PickleSerializer but supports fewer datatypes.
+    """
+
+    _dumps = marshal.dumps
+    _loads = marshal.loads
 
 
-def dump_pickle(obj):
-    return cPickle.dumps(obj, 2)
+class MUTF8Deserializer(Serializer):
+    """
+    Deserializes streams written by Java's DataOutputStream.writeUTF().
+    """
 
+    def _loads(self, stream):
+        length = struct.unpack('>H', stream.read(2))[0]
+        return stream.read(length).decode('utf8')
 
-load_pickle = cPickle.loads
+    def load_stream(self, stream):
+        while True:
+            try:
+                yield self._loads(stream)
+            except struct.error:
+                return
+            except EOFError:
+                return
 
 
 def read_long(stream):
@@ -90,43 +305,4 @@ def write_int(value, stream):
 
 def write_with_length(obj, stream):
     write_int(len(obj), stream)
-    stream.write(obj)
-
-
-def read_mutf8(stream):
-    """
-    Read a string written with Java's DataOutputStream.writeUTF() method.
-    """
-    length = struct.unpack('>H', stream.read(2))[0]
-    return stream.read(length).decode('utf8')
-
-
-def read_with_length(stream):
-    length = read_int(stream)
-    obj = stream.read(length)
-    if obj == "":
-        raise EOFError
-    return obj
-
-
-def read_from_pickle_file(stream):
-    try:
-        while True:
-            obj = load_pickle(read_with_length(stream))
-            if type(obj) == Batch:  # We don't care about inheritance
-                for item in obj.items:
-                    yield item
-            else:
-                yield obj
-    except EOFError:
-        return
-
-
-def read_pairs_from_pickle_file(stream):
-    try:
-        while True:
-            a = load_pickle(read_with_length(stream))
-            b = load_pickle(read_with_length(stream))
-            yield (a, b)
-    except EOFError:
-        return
\ No newline at end of file
+    stream.write(obj)
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/cbb7f04a/python/pyspark/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 29d6a12..621e1cb 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -86,7 +86,8 @@ class TestCheckpoint(PySparkTestCase):
         time.sleep(1)  # 1 second
 
         self.assertTrue(flatMappedRDD.getCheckpointFile() is not None)
-        recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile())
+        recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile(),
+                                            flatMappedRDD._jrdd_deserializer)
         self.assertEquals([1, 2, 3, 4], recovered.collect())
 
 

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/cbb7f04a/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 4e64557..5b16d5d 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -30,13 +30,17 @@ from pyspark.accumulators import _accumulatorRegistry
 from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.cloudpickle import CloudPickler
 from pyspark.files import SparkFiles
-from pyspark.serializers import write_with_length, read_with_length, write_int, \
-    read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file, \
-    SpecialLengths, read_mutf8, read_pairs_from_pickle_file
+from pyspark.serializers import write_with_length, write_int, read_long, \
+    write_long, read_int, SpecialLengths, MUTF8Deserializer, PickleSerializer
+
+
+pickleSer = PickleSerializer()
+mutf8_deserializer = MUTF8Deserializer()
 
 
 def load_obj(infile):
-    return load_pickle(standard_b64decode(infile.readline().strip()))
+    decoded = standard_b64decode(infile.readline().strip())
+    return pickleSer._loads(decoded)
 
 
 def report_times(outfile, boot, init, finish):
@@ -53,7 +57,7 @@ def main(infile, outfile):
         return
 
     # fetch name of workdir
-    spark_files_dir = read_mutf8(infile)
+    spark_files_dir = mutf8_deserializer._loads(infile)
     SparkFiles._root_directory = spark_files_dir
     SparkFiles._is_running_on_worker = True
 
@@ -61,31 +65,24 @@ def main(infile, outfile):
     num_broadcast_variables = read_int(infile)
     for _ in range(num_broadcast_variables):
         bid = read_long(infile)
-        value = read_with_length(infile)
-        _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value))
+        value = pickleSer._read_with_length(infile)
+        _broadcastRegistry[bid] = Broadcast(bid, value)
 
     # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
     sys.path.append(spark_files_dir) # *.py files that were added will be copied here
     num_python_includes =  read_int(infile)
     for _ in range(num_python_includes):
-        sys.path.append(os.path.join(spark_files_dir, read_mutf8(infile)))
+        filename = mutf8_deserializer._loads(infile)
+        sys.path.append(os.path.join(spark_files_dir, filename))
 
-    # now load function
+    # Load this stage's function and serializer:
     func = load_obj(infile)
-    bypassSerializer = load_obj(infile)
-    stageInputIsPairs = load_obj(infile)
-    if bypassSerializer:
-        dumps = lambda x: x
-    else:
-        dumps = dump_pickle
+    deserializer = load_obj(infile)
+    serializer = load_obj(infile)
     init_time = time.time()
-    if stageInputIsPairs:
-        iterator = read_pairs_from_pickle_file(infile)
-    else:
-        iterator = read_from_pickle_file(infile)
     try:
-        for obj in func(split_index, iterator):
-            write_with_length(dumps(obj), outfile)
+        iterator = deserializer.load_stream(infile)
+        serializer.dump_stream(func(split_index, iterator), outfile)
     except Exception as e:
         write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
         write_with_length(traceback.format_exc(), outfile)
@@ -96,7 +93,7 @@ def main(infile, outfile):
     write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
     write_int(len(_accumulatorRegistry), outfile)
     for (aid, accum) in _accumulatorRegistry.items():
-        write_with_length(dump_pickle((aid, accum._value)), outfile)
+        pickleSer._write_with_length((aid, accum._value), outfile)
 
 
 if __name__ == '__main__':

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/cbb7f04a/python/run-tests
----------------------------------------------------------------------
diff --git a/python/run-tests b/python/run-tests
index cbc554e..d4dad67 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -37,6 +37,7 @@ run_test "pyspark/rdd.py"
 run_test "pyspark/context.py"
 run_test "-m doctest pyspark/broadcast.py"
 run_test "-m doctest pyspark/accumulators.py"
+run_test "-m doctest pyspark/serializers.py"
 run_test "pyspark/tests.py"
 
 if [[ $FAILED != 0 ]]; then


Mime
View raw message