spark-commits mailing list archives

From ma...@apache.org
Subject git commit: Spark 1162 Implemented takeOrdered in pyspark.
Date Thu, 03 Apr 2014 22:42:41 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-0.9 a6c955a1a -> 28e764379


Spark 1162 Implemented takeOrdered in pyspark.

Since Python does not have a max-heap library, and the usual tricks such as inverting
values do not work in all cases, we include our own implementation of a max heap.
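
As an aside, here is a minimal, hypothetical sketch of the same bounded-selection idea
(it is not the MaxHeapQ code added by this commit, which appears in the diff below, and
the names _ReverseOrder and smallest_n are illustrative only). Python's heapq module only
exposes a min-heap, and negating values to fake a max-heap only works for numeric data;
one alternative is a wrapper that reverses comparisons:

    import heapq

    class _ReverseOrder(object):
        """Wrapper with reversed comparisons, so heapq's min-heap behaves as a max-heap."""
        def __init__(self, value):
            self.value = value

        def __lt__(self, other):
            return other.value < self.value

    def smallest_n(iterator, num):
        # Keep the num smallest elements seen so far; the largest of them sits at
        # the heap root and is evicted whenever a smaller element arrives.
        heap = []
        for item in iterator:
            if len(heap) < num:
                heapq.heappush(heap, _ReverseOrder(item))
            elif item < heap[0].value:
                heapq.heapreplace(heap, _ReverseOrder(item))
        return sorted(w.value for w in heap)

    print(smallest_n(iter([10, 1, 2, 9, 3, 4, 5, 6, 7]), 6))  # [1, 2, 3, 4, 5, 6]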

Author: Prashant Sharma <prashant.s@imaginea.com>

Closes #97 from ScrapCodes/SPARK-1162/pyspark-top-takeOrdered2 and squashes the following commits:

35f86ba [Prashant Sharma] code review
2b1124d [Prashant Sharma] fixed tests
e8a08e2 [Prashant Sharma] Code review comments.
49e6ba7 [Prashant Sharma] SPARK-1162 added takeOrdered to pyspark

(cherry picked from commit c1ea3afb516c204925259f0928dfb17d0fa89621)
Signed-off-by: Matei Zaharia <matei@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/28e76437
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/28e76437
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/28e76437

Branch: refs/heads/branch-0.9
Commit: 28e7643797f3b18fb8351899833035f7141d4792
Parents: a6c955a
Author: Prashant Sharma <prashant.s@imaginea.com>
Authored: Thu Apr 3 15:42:17 2014 -0700
Committer: Matei Zaharia <matei@databricks.com>
Committed: Thu Apr 3 15:42:31 2014 -0700

----------------------------------------------------------------------
 python/pyspark/rdd.py | 107 ++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 102 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/28e76437/python/pyspark/rdd.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index deaf896..ace8476 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -28,7 +28,7 @@ from subprocess import Popen, PIPE
 from tempfile import NamedTemporaryFile
 from threading import Thread
 import warnings
-from heapq import heappush, heappop, heappushpop
+import heapq
 
 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
     BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long
@@ -40,9 +40,9 @@ from pyspark.storagelevel import StorageLevel
 
 from py4j.java_collections import ListConverter, MapConverter
 
-
 __all__ = ["RDD"]
 
+
 def _extract_concise_traceback():
     tb = traceback.extract_stack()
     if len(tb) == 0:
@@ -84,6 +84,73 @@ class _JavaStackTrace(object):
         if _spark_stack_depth == 0:
             self._context._jsc.setCallSite(None)
 
+class MaxHeapQ(object):
+    """
+    An implementation of MaxHeap.
+    >>> import pyspark.rdd
+    >>> heap = pyspark.rdd.MaxHeapQ(5)
+    >>> [heap.insert(i) for i in range(10)]
+    [None, None, None, None, None, None, None, None, None, None]
+    >>> sorted(heap.getElements())
+    [0, 1, 2, 3, 4]
+    >>> heap = pyspark.rdd.MaxHeapQ(5)
+    >>> [heap.insert(i) for i in range(9, -1, -1)]
+    [None, None, None, None, None, None, None, None, None, None]
+    >>> sorted(heap.getElements())
+    [0, 1, 2, 3, 4]
+    >>> heap = pyspark.rdd.MaxHeapQ(1)
+    >>> [heap.insert(i) for i in range(9, -1, -1)]
+    [None, None, None, None, None, None, None, None, None, None]
+    >>> heap.getElements()
+    [0]
+    """
+
+    def __init__(self, maxsize):
+        # we start from q[1], this makes calculating children as trivial as 2 * k
+        self.q = [0]
+        self.maxsize = maxsize
+
+    def _swim(self, k):
+        while (k > 1) and (self.q[k/2] < self.q[k]):
+            self._swap(k, k/2)
+            k = k/2
+
+    def _swap(self, i, j):
+        t = self.q[i]
+        self.q[i] = self.q[j]
+        self.q[j] = t
+
+    def _sink(self, k):
+        N = self.size()
+        while 2 * k <= N:
+            j = 2 * k
+            # Here we test if both children are greater than parent
+            # if not swap with larger one.
+            if j < N and self.q[j] < self.q[j + 1]:
+                j = j + 1
+            if(self.q[k] > self.q[j]):
+                break
+            self._swap(k, j)
+            k = j
+
+    def size(self):
+        return len(self.q) - 1
+
+    def insert(self, value):
+        if (self.size()) < self.maxsize:
+            self.q.append(value)
+            self._swim(self.size())
+        else:
+            self._replaceRoot(value)
+
+    def getElements(self):
+        return self.q[1:]
+
+    def _replaceRoot(self, value):
+        if(self.q[1] > value):
+            self.q[1] = value
+            self._sink(1)
+
 class RDD(object):
     """
     A Resilient Distributed Dataset (RDD), the basic abstraction in Spark.
@@ -630,16 +697,16 @@ class RDD(object):
         Note: It returns the list sorted in descending order.
         >>> sc.parallelize([10, 4, 2, 12, 3]).top(1)
         [12]
-        >>> sc.parallelize([2, 3, 4, 5, 6]).cache().top(2)
+        >>> sc.parallelize([2, 3, 4, 5, 6], 2).cache().top(2)
         [6, 5]
         """
         def topIterator(iterator):
             q = []
             for k in iterator:
                 if len(q) < num:
-                    heappush(q, k)
+                    heapq.heappush(q, k)
                 else:
-                    heappushpop(q, k)
+                    heapq.heappushpop(q, k)
             yield q
 
         def merge(a, b):
@@ -647,6 +714,36 @@ class RDD(object):
 
         return sorted(self.mapPartitions(topIterator).reduce(merge), reverse=True)
 
+    def takeOrdered(self, num, key=None):
+        """
+        Get the N elements from a RDD ordered in ascending order or as specified
+        by the optional key function. 
+
+        >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7]).takeOrdered(6)
+        [1, 2, 3, 4, 5, 6]
+        >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7], 2).takeOrdered(6, key=lambda x: -x)
+        [10, 9, 7, 6, 5, 4]
+        """
+
+        def topNKeyedElems(iterator, key_=None):
+            q = MaxHeapQ(num)
+            for k in iterator:
+                if key_ != None:
+                    k = (key_(k), k)
+                q.insert(k)
+            yield q.getElements()
+
+        def unKey(x, key_=None):
+            if key_ != None:
+                x = [i[1] for i in x]
+            return x
+        
+        def merge(a, b):
+            return next(topNKeyedElems(a + b))
+        result = self.mapPartitions(lambda i: topNKeyedElems(i, key)).reduce(merge)
+        return sorted(unKey(result, key), key=key)
+
+
     def take(self, num):
         """
         Take the first num elements of the RDD.
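
For reference, a rough usage sketch of the new takeOrdered API against a local
SparkContext (the master "local" and the app name are illustrative, not from this
commit). When a key function is given, each element x is heaped as the pair
(key(x), x) and unwrapped again before the sorted result is returned:

    from pyspark import SparkContext

    sc = SparkContext("local", "takeOrdered-demo")  # illustrative master/app name
    rdd = sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7], 2)

    print(rdd.takeOrdered(3))                    # [1, 2, 3]   -- three smallest
    print(rdd.takeOrdered(3, key=lambda x: -x))  # [10, 9, 7]  -- three largest via key
    sc.stop()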

