spark-commits mailing list archives

From m...@apache.org
Subject spark git commit: [SPARK-4327] [PySpark] Python API for RDD.randomSplit()
Date Wed, 19 Nov 2014 00:37:39 GMT
Repository: spark
Updated Branches:
  refs/heads/master bb4604615 -> 7f22fa81e


[SPARK-4327] [PySpark] Python API for RDD.randomSplit()

```
pyspark.RDD.randomSplit(self, weights, seed=None)
    Randomly splits this RDD with the provided weights.

    :param weights: weights for splits, will be normalized if they don't sum to 1
    :param seed: random seed
    :return: split RDDs in a list

    >>> rdd = sc.parallelize(range(10), 1)
    >>> rdd1, rdd2, rdd3 = rdd.randomSplit([0.4, 0.6, 1.0], 11)
    >>> rdd1.collect()
    [3, 6]
    >>> rdd2.collect()
    [0, 5, 7]
    >>> rdd3.collect()
    [1, 2, 4, 8, 9]
```
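
A typical use, sketched below, is an 80/20 train/test split; this snippet is an illustration, not part of the patch, and assumes a live SparkContext `sc`:

```
# usage sketch (assumes a running SparkContext `sc`); not from this patch
rdd = sc.parallelize(range(100))
train, test = rdd.randomSplit([0.8, 0.2], seed=42)
# with a fixed seed the split is reproducible, and each element lands
# in exactly one of the returned RDDs
assert train.count() + test.count() == 100
```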

Author: Davies Liu <davies@databricks.com>

Closes #3193 from davies/randomSplit and squashes the following commits:

78bf997 [Davies Liu] fix tests, do not use numpy in randomSplit, no performance gain
f5fdf63 [Davies Liu] fix bug with int in weights
4dfa2cd [Davies Liu] refactor
f866bcf [Davies Liu] remove unneeded change
c7a2007 [Davies Liu] switch to python implementation
95a48ac [Davies Liu] Merge branch 'master' of github.com:apache/spark into randomSplit
0d9b256 [Davies Liu] refactor
1715ee3 [Davies Liu] address comments
41fce54 [Davies Liu] randomSplit()


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7f22fa81
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7f22fa81
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7f22fa81

Branch: refs/heads/master
Commit: 7f22fa81ebd5e501fcb0e1da5506d1d4fb9250cf
Parents: bb46046
Author: Davies Liu <davies@databricks.com>
Authored: Tue Nov 18 16:37:35 2014 -0800
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Tue Nov 18 16:37:35 2014 -0800

----------------------------------------------------------------------
 python/pyspark/rdd.py        | 30 +++++++++++++++++++++++++++---
 python/pyspark/rddsampler.py | 14 ++++++++++++++
 2 files changed, 41 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7f22fa81/python/pyspark/rdd.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 08d0474..50535d2 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -28,7 +28,7 @@ from threading import Thread
 import warnings
 import heapq
 import bisect
-from random import Random
+import random
 from math import sqrt, log, isinf, isnan
 
 from pyspark.accumulators import PStatsParam
@@ -38,7 +38,7 @@ from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
 from pyspark.join import python_join, python_left_outer_join, \
     python_right_outer_join, python_full_outer_join, python_cogroup
 from pyspark.statcounter import StatCounter
-from pyspark.rddsampler import RDDSampler, RDDStratifiedSampler
+from pyspark.rddsampler import RDDSampler, RDDRangeSampler, RDDStratifiedSampler
 from pyspark.storagelevel import StorageLevel
 from pyspark.resultiterable import ResultIterable
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
@@ -316,6 +316,30 @@ class RDD(object):
         assert fraction >= 0.0, "Negative fraction value: %s" % fraction
         return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)
 
+    def randomSplit(self, weights, seed=None):
+        """
+        Randomly splits this RDD with the provided weights.
+
+        :param weights: weights for splits, will be normalized if they don't sum to 1
+        :param seed: random seed
+        :return: split RDDs in a list
+
+        >>> rdd = sc.parallelize(range(5), 1)
+        >>> rdd1, rdd2 = rdd.randomSplit([2, 3], 17)
+        >>> rdd1.collect()
+        [1, 3]
+        >>> rdd2.collect()
+        [0, 2, 4]
+        """
+        s = float(sum(weights))
+        cweights = [0.0]
+        for w in weights:
+            cweights.append(cweights[-1] + w / s)
+        if seed is None:
+            seed = random.randint(0, 2 ** 32 - 1)
+        return [self.mapPartitionsWithIndex(RDDRangeSampler(lb, ub, seed).func, True)
+                for lb, ub in zip(cweights, cweights[1:])]
+
     # this is ported from scala/spark/RDD.scala
     def takeSample(self, withReplacement, num, seed=None):
         """
@@ -341,7 +365,7 @@ class RDD(object):
         if initialCount == 0:
             return []
 
-        rand = Random(seed)
+        rand = random.Random(seed)
 
         if (not withReplacement) and num >= initialCount:
             # shuffle current RDD and return
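
The cumulative-weight list built in `randomSplit` above is what turns arbitrary weights into adjacent sampling ranges. A minimal standalone sketch of that loop (names copied from the patch; runs without Spark):

```
# standalone sketch of the cumulative-weight computation in randomSplit
weights = [2, 3]
s = float(sum(weights))            # 5.0 -- also normalizes int weights
cweights = [0.0]
for w in weights:
    cweights.append(cweights[-1] + w / s)
# cweights == [0.0, 0.4, 1.0]; zip pairs them into adjacent ranges
print(list(zip(cweights, cweights[1:])))   # [(0.0, 0.4), (0.4, 1.0)]
```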

http://git-wip-us.apache.org/repos/asf/spark/blob/7f22fa81/python/pyspark/rddsampler.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py
index f5c3cfd..558dcfd 100644
--- a/python/pyspark/rddsampler.py
+++ b/python/pyspark/rddsampler.py
@@ -115,6 +115,20 @@ class RDDSampler(RDDSamplerBase):
                     yield obj
 
 
+class RDDRangeSampler(RDDSamplerBase):
+
+    def __init__(self, lowerBound, upperBound, seed=None):
+        RDDSamplerBase.__init__(self, False, seed)
+        self._use_numpy = False  # no performance gain from numpy
+        self._lowerBound = lowerBound
+        self._upperBound = upperBound
+
+    def func(self, split, iterator):
+        for obj in iterator:
+            if self._lowerBound <= self.getUniformSample(split) < self._upperBound:
+                yield obj
+
+
 class RDDStratifiedSampler(RDDSamplerBase):
 
     def __init__(self, withReplacement, fractions, seed=None):
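
`RDDRangeSampler` keeps each element whose uniform draw falls inside its half-open `[lowerBound, upperBound)` range. Because `randomSplit` passes the same seed to every sampler, each element produces the same draw in every split, so the splits are disjoint and together cover the whole RDD. A rough standalone sketch of that mechanism (the per-element RNG below is a stand-in for `getUniformSample`, not the actual RDDSamplerBase logic):

```
import random

# sketch: with a shared seed, every range sees the SAME draw per element,
# so each element matches exactly one half-open [lb, ub) range
def simulate_split(data, ranges, seed):
    parts = [[] for _ in ranges]
    for i, obj in enumerate(data):
        x = random.Random(seed + i).random()   # stand-in per-element draw
        for j, (lb, ub) in enumerate(ranges):
            if lb <= x < ub:
                parts[j].append(obj)
    return parts

parts = simulate_split(range(10), [(0.0, 0.4), (0.4, 1.0)], seed=11)
assert sum(len(p) for p in parts) == 10   # nothing lost or duplicated
```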

