spark-commits mailing list archives

From r...@apache.org
Subject spark git commit: [SPARK-12662][SQL] Fix DataFrame.randomSplit to avoid creating overlapping splits
Date Thu, 07 Jan 2016 18:37:18 GMT
Repository: spark
Updated Branches:
  refs/heads/master 592f64985 -> f194d9911


[SPARK-12662][SQL] Fix DataFrame.randomSplit to avoid creating overlapping splits

https://issues.apache.org/jira/browse/SPARK-12662

cc yhuai

Author: Sameer Agarwal <sameer@databricks.com>

Closes #10626 from sameeragarwal/randomsplit.
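For context, the overlap arises because `Sample` re-evaluates its child each
time a split is materialized; if the child's partitions can yield rows in a
different order on each evaluation, the same row may land in more than one
split. A minimal sketch of that scenario, mirroring the new test below
(assumes a `sqlContext` in scope with its implicits imported):

    import scala.util.Random

    // Partitions that yield rows in a nondeterministic order.
    val data = sqlContext.sparkContext
      .parallelize(1 to 600, 2)
      .mapPartitions(Random.shuffle(_))
      .toDF("id")

    val Array(a, b) = data.randomSplit(Array[Double](2, 3), seed = 1)

    // With the per-partition sort added by this commit, the splits are
    // disjoint and together cover the whole dataset.
    assert(a.intersect(b).count() == 0)
    assert(a.count() + b.count() == 600)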


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f194d991
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f194d991
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f194d991

Branch: refs/heads/master
Commit: f194d9911a93fc3a78be820096d4836f22d09976
Parents: 592f649
Author: Sameer Agarwal <sameer@databricks.com>
Authored: Thu Jan 7 10:37:15 2016 -0800
Committer: Reynold Xin <rxin@databricks.com>
Committed: Thu Jan 7 10:37:15 2016 -0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/sql/DataFrame.scala  |  7 ++++++-
 .../apache/spark/sql/DataFrameStatSuite.scala   | 22 ++++++++++++++++++++
 2 files changed, 28 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f194d991/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 7cf2818..60d2f05 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -1062,10 +1062,15 @@ class DataFrame private[sql](
    * @since 1.4.0
    */
   def randomSplit(weights: Array[Double], seed: Long): Array[DataFrame] = {
+    // It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its
+    // constituent partitions each time a split is materialized which could result in
+    // overlapping splits. To prevent this, we explicitly sort each input partition to make the
+    // ordering deterministic.
+    val sorted = Sort(logicalPlan.output.map(SortOrder(_, Ascending)), global = false, logicalPlan)
     val sum = weights.sum
     val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
     normalizedCumWeights.sliding(2).map { x =>
-      new DataFrame(sqlContext, Sample(x(0), x(1), withReplacement = false, seed, logicalPlan))
+      new DataFrame(sqlContext, Sample(x(0), x(1), withReplacement = false, seed, sorted))
     }.toArray
   }
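As a worked example of the weight normalization above (plain Scala, no Spark
required): `scanLeft` turns the normalized weights into cumulative bounds, and
`sliding(2)` pairs them into disjoint sampling ranges.

    val weights = Array(2.0, 3.0)
    val sum = weights.sum                                    // 5.0
    val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
    // normalizedCumWeights == Array(0.0, 0.4, 1.0)
    val ranges = normalizedCumWeights.sliding(2).toArray
    // ranges: [0.0, 0.4) and [0.4, 1.0) -- each row's random draw falls into
    // exactly one range, so with a deterministic row order (and hence a
    // deterministic draw per row) the resulting splits cannot overlap.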
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f194d991/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index b15af42..63ad6c4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -62,6 +62,28 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("randomSplit on reordered partitions") {
+    // This test ensures that randomSplit does not create overlapping splits even when the
+    // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of
+    // rows in each partition.
+    val data =
+      sparkContext.parallelize(1 to 600, 2).mapPartitions(scala.util.Random.shuffle(_)).toDF("id")
+    val splits = data.randomSplit(Array[Double](2, 3), seed = 1)
+
+    assert(splits.length == 2, "wrong number of splits")
+
+    // Verify that the splits span the entire dataset
+    assert(splits.flatMap(_.collect()).toSet == data.collect().toSet)
+
+    // Verify that the splits don't overlap
+    assert(splits(0).intersect(splits(1)).collect().isEmpty)
+
+    // Verify that the results are deterministic across multiple runs
+    val firstRun = splits.toSeq.map(_.collect().toSeq)
+    val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq)
+    assert(firstRun == secondRun)
+  }
+
   test("pearson correlation") {
     val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c")
     val corr1 = df.stat.corr("a", "b", "pearson")

