spark-commits mailing list archives

From r...@apache.org
Subject spark git commit: [SPARK-12340] Fix overflow in various take functions.
Date Sat, 09 Jan 2016 19:22:02 GMT
Repository: spark
Updated Branches:
  refs/heads/master 3d77cffec -> b23c4521f


[SPARK-12340] Fix overflow in various take functions.

This is a follow-up to the original patch #10562.

Author: Reynold Xin <rxin@databricks.com>

Closes #10670 from rxin/SPARK-12340.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b23c4521
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b23c4521
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b23c4521

Branch: refs/heads/master
Commit: b23c4521f5df905e4fe4d79dd5b670286e2697f7
Parents: 3d77cff
Author: Reynold Xin <rxin@databricks.com>
Authored: Sat Jan 9 11:21:58 2016 -0800
Committer: Reynold Xin <rxin@databricks.com>
Committed: Sat Jan 9 11:21:58 2016 -0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/rdd/AsyncRDDActions.scala    |  8 ++++----
 core/src/main/scala/org/apache/spark/rdd/RDD.scala      |  4 ++--
 core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala |  4 ++++
 .../org/apache/spark/sql/execution/SparkPlan.scala      |  7 +++----
 .../scala/org/apache/spark/sql/DataFrameSuite.scala     |  6 ++++++
 .../test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 12 ------------
 6 files changed, 19 insertions(+), 22 deletions(-)
----------------------------------------------------------------------
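
For readers of the patch: the underlying bug is integer overflow in the partition-range arithmetic of the take paths. When take()/limit() is called with a count close to Int.MaxValue, the estimated number of partitions to scan (numPartsToTry) can itself be pushed up toward Int.MaxValue, and adding it to partsScanned in Int arithmetic wraps to a negative value; the resulting partition range is empty and the scan loop makes no progress. The hunks below keep partsScanned as an Int but perform the addition in Long (numPartsToTry remains a Long after the earlier patch) and narrow back to Int only after capping at totalParts. A minimal standalone sketch of the failure mode and of that pattern, using illustrative values rather than code from the patch:

    object Spark12340Sketch {
      def main(args: Array[String]): Unit = {
        val totalParts = 2                 // partitions in the RDD
        val partsScanned = 1               // partitions scanned so far
        val numPartsToTry = Int.MaxValue   // what a take(~Int.MaxValue) can drive the estimate to

        // Int arithmetic: the sum wraps to a negative number, so the partition
        // range is empty and the loop cannot advance.
        val brokenEnd = math.min(partsScanned + numPartsToTry, totalParts)
        println(s"Int sum = ${partsScanned + numPartsToTry}, partitions picked = ${(partsScanned until brokenEnd).size}")

        // The pattern used in this commit: add in Long, cap at totalParts, then narrow to Int.
        val safeEnd = math.min(partsScanned + numPartsToTry.toLong, totalParts).toInt
        println(s"Long sum = ${partsScanned + numPartsToTry.toLong}, partitions picked = ${(partsScanned until safeEnd).size}")
      }
    }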


http://git-wip-us.apache.org/repos/asf/spark/blob/b23c4521/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
index 94719a4..7de9df1 100644
--- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -77,7 +77,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
       This implementation is non-blocking, asynchronously handling the
       results of each job and triggering the next job using callbacks on futures.
      */
-    def continue(partsScanned: Long)(implicit jobSubmitter: JobSubmitter) : Future[Seq[T]] =
+    def continue(partsScanned: Int)(implicit jobSubmitter: JobSubmitter): Future[Seq[T]] =
       if (results.size >= num || partsScanned >= totalParts) {
         Future.successful(results.toSeq)
       } else {
@@ -99,7 +99,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
         }
 
         val left = num - results.size
-        val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt
+        val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
 
         val buf = new Array[Array[T]](p.size)
         self.context.setCallSite(callSite)
@@ -109,13 +109,13 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
           p,
           (index: Int, data: Array[T]) => buf(index) = data,
           Unit)
-        job.flatMap {_ =>
+        job.flatMap { _ =>
           buf.foreach(results ++= _.take(num - results.size))
           continue(partsScanned + p.size)
         }
       }
 
-    new ComplexFutureAction[Seq[T]](continue(0L)(_))
+    new ComplexFutureAction[Seq[T]](continue(0)(_))
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/b23c4521/core/src/main/scala/org/apache/spark/rdd/RDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index e25657c..de7102f 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1190,7 +1190,7 @@ abstract class RDD[T: ClassTag](
     } else {
       val buf = new ArrayBuffer[T]
       val totalParts = this.partitions.length
-      var partsScanned = 0L
+      var partsScanned = 0
       while (buf.size < num && partsScanned < totalParts) {
        // The number of partitions to try in this iteration. It is ok for this number to be
         // greater than totalParts because we actually cap it at totalParts in runJob.
@@ -1209,7 +1209,7 @@ abstract class RDD[T: ClassTag](
         }
 
         val left = num - buf.size
-        val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt
+        val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
         val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p)
 
         res.foreach(buf ++= _.take(num - buf.size))

http://git-wip-us.apache.org/repos/asf/spark/blob/b23c4521/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 24acbed..ef2ed44 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -482,6 +482,10 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext {
     assert(nums.take(501) === (1 to 501).toArray)
     assert(nums.take(999) === (1 to 999).toArray)
     assert(nums.take(1000) === (1 to 999).toArray)
+
+    nums = sc.parallelize(1 to 2, 2)
+    assert(nums.take(2147483638).size === 2)
+    assert(nums.takeAsync(2147483638).get.size === 2)
   }
 
   test("top with predefined ordering") {

http://git-wip-us.apache.org/repos/asf/spark/blob/b23c4521/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 21a6fba..2355de3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -165,7 +165,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
 
     val buf = new ArrayBuffer[InternalRow]
     val totalParts = childRDD.partitions.length
-    var partsScanned = 0L
+    var partsScanned = 0
     while (buf.size < n && partsScanned < totalParts) {
       // The number of partitions to try in this iteration. It is ok for this number to be
       // greater than totalParts because we actually cap it at totalParts in runJob.
@@ -183,10 +183,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
       numPartsToTry = math.max(0, numPartsToTry)  // guard against negative num of partitions
 
       val left = n - buf.size
-      val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt
+      val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
       val sc = sqlContext.sparkContext
-      val res =
-        sc.runJob(childRDD, (it: Iterator[InternalRow]) => it.take(left).toArray, p)
+      val res = sc.runJob(childRDD, (it: Iterator[InternalRow]) => it.take(left).toArray, p)
 
       res.foreach(buf ++= _.take(n - buf.size))
       partsScanned += p.size

http://git-wip-us.apache.org/repos/asf/spark/blob/b23c4521/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index ade1391..983dfbd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -308,6 +308,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     checkAnswer(
       mapData.toDF().limit(1),
       mapData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq)))
+
+    // SPARK-12340: overstep the bounds of Int in SparkPlan.executeTake
+    checkAnswer(
+      sqlContext.range(2).limit(2147483638),
+      Row(0) :: Row(1) :: Nil
+    )
   }
 
   test("except") {

http://git-wip-us.apache.org/repos/asf/spark/blob/b23c4521/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index bd987ae..5de0979 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2067,16 +2067,4 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       )
     }
   }
-
-  test("SPARK-12340: overstep the bounds of Int in SparkPlan.executeTake") {
-    val rdd = sqlContext.sparkContext.parallelize(1 to 3 , 3 )
-    rdd.toDF("key").registerTempTable("spark12340")
-    checkAnswer(
-      sql("select key from spark12340 limit 2147483638"),
-      Row(1) :: Row(2) :: Row(3) :: Nil
-    )
-    assert(rdd.take(2147483638).size === 3)
-    assert(rdd.takeAsync(2147483638).get.size === 3)
-  }
-
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org

