spark-commits mailing list archives

From ma...@apache.org
Subject git commit: SPARK-1240: handle the case of empty RDD when takeSample
Date Mon, 17 Mar 2014 05:41:10 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-0.9 1dc1e988f -> af7e8b1c9


SPARK-1240: handle the case of empty RDD when takeSample

https://spark-project.atlassian.net/browse/SPARK-1240

It seems that the current implementation does not handle the empty-RDD case when running takeSample.

In this patch, before calling sample() inside the takeSample API, I add a check for this case
and return an empty Array when the RDD is empty; in sample(), I also add a check for
invalid fraction values.

I also add several lines to the test cases to cover this scenario.
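For illustration, a minimal sketch of the intended behavior after this patch (hypothetical spark-shell session; assumes a live SparkContext named sc, and the output shown is illustrative):

    // takeSample on an empty RDD now short-circuits and returns an empty Array
    val empty = sc.parallelize(Seq.empty[Int], 2)
    empty.takeSample(false, 20, 1)
    // res0: Array[Int] = Array()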

Author: CodingCat <zhunansjtu@gmail.com>

Closes #135 from CodingCat/SPARK-1240 and squashes the following commits:

fef57d4 [CodingCat] fix the same problem in PySpark
36db06b [CodingCat] create new test cases for takeSample from an empty rdd
810948d [CodingCat] further fix
a40e8fb [CodingCat] replace if with require
ad483fd [CodingCat] handle the case with empty RDD when take sample

Conflicts:
	core/src/main/scala/org/apache/spark/rdd/RDD.scala
	core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/af7e8b1c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/af7e8b1c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/af7e8b1c

Branch: refs/heads/branch-0.9
Commit: af7e8b1c9913ebb6d4131a03cfe0cd0a2b38c529
Parents: 1dc1e98
Author: CodingCat <zhunansjtu@gmail.com>
Authored: Sun Mar 16 22:14:59 2014 -0700
Committer: Matei Zaharia <matei@databricks.com>
Committed: Sun Mar 16 22:40:22 2014 -0700

----------------------------------------------------------------------
 core/src/main/scala/org/apache/spark/rdd/RDD.scala      | 10 ++++++++--
 core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala |  7 +++++++
 python/pyspark/rdd.py                                   |  4 ++++
 3 files changed, 19 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/af7e8b1c/core/src/main/scala/org/apache/spark/rdd/RDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 1472c92..b529754 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -319,8 +319,10 @@ abstract class RDD[T: ClassTag](
   /**
    * Return a sampled subset of this RDD.
    */
-  def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] =
+  def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] = {
+    require(fraction >= 0.0, "Invalid fraction value: " + fraction)
     new SampledRDD(this, withReplacement, fraction, seed)
+  }
 
   def takeSample(withReplacement: Boolean, num: Int, seed: Int): Array[T] = {
     var fraction = 0.0
@@ -333,6 +335,10 @@ abstract class RDD[T: ClassTag](
       throw new IllegalArgumentException("Negative number of elements requested")
     }
 
+    if (initialCount == 0) {
+      return new Array[T](0)
+    }
+
     if (initialCount > Integer.MAX_VALUE - 1) {
       maxSelected = Integer.MAX_VALUE - 1
     } else {
@@ -351,7 +357,7 @@ abstract class RDD[T: ClassTag](
     var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
 
     // If the first sample didn't turn out large enough, keep trying to take samples;
-    // this shouldn't happen often because we use a big multiplier for thei initial size
+    // this shouldn't happen often because we use a big multiplier for the initial size
     while (samples.length < total) {
       samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/af7e8b1c/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 559ea05..ac9df34 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -455,6 +455,7 @@ class RDDSuite extends FunSuite with SharedSparkContext {
 
   test("takeSample") {
     val data = sc.parallelize(1 to 100, 2)
+
     for (seed <- 1 to 5) {
       val sample = data.takeSample(withReplacement=false, 20, seed)
       assert(sample.size === 20)        // Got exactly 20 elements
@@ -486,6 +487,12 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     }
   }
 
+  test("takeSample from an empty rdd") {
+    val emptySet = sc.parallelize(Seq.empty[Int], 2)
+    val sample = emptySet.takeSample(false, 20, 1)
+    assert(sample.length === 0)
+  }
+
   test("runJob on an invalid partition") {
     intercept[IllegalArgumentException] {
      sc.runJob(sc.parallelize(1 to 10, 2), {iter: Iterator[Int] => iter.size}, Seq(0, 1, 2), false)

http://git-wip-us.apache.org/repos/asf/spark/blob/af7e8b1c/python/pyspark/rdd.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index c29cefa..6d05ff2 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -254,6 +254,7 @@ class RDD(object):
         >>> sc.parallelize(range(0, 100)).sample(False, 0.1, 2).collect() #doctest: +SKIP
         [2, 3, 20, 21, 24, 41, 42, 66, 67, 89, 90, 98]
         """
+        assert fraction >= 0.0, "Invalid fraction value: %s" % fraction
         return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)
 
     # this is ported from scala/spark/RDD.scala
@@ -274,6 +275,9 @@ class RDD(object):
         if (num < 0):
             raise ValueError
 
+        if (initialCount == 0):
+            return list()
+
         if initialCount > sys.maxint - 1:
             maxSelected = sys.maxint - 1
         else:

