spark-commits mailing list archives

From samee...@apache.org
Subject spark git commit: [SPARK-23207][SQL] Shuffle+Repartition on a DataFrame could lead to incorrect answers
Date Fri, 26 Jan 2018 23:01:09 GMT
Repository: spark
Updated Branches:
  refs/heads/master a8a3e9b7c -> 94c67a76e


[SPARK-23207][SQL] Shuffle+Repartition on a DataFrame could lead to incorrect answers

## What changes were proposed in this pull request?

Currently, shuffle repartition uses RoundRobinPartitioning, and the generated result is
nondeterministic because the order of the input rows is not deterministic.

The bug can be triggered when a repartition call follows a shuffle (which leads to
nondeterministic row ordering), as in the pattern below:
upstream stage -> repartition stage -> result stage
(-> indicates a shuffle)
When one of the executor processes goes down, some tasks of the repartition stage are retried
and produce a different row ordering. The result-stage tasks that are also retried then read
different data than the tasks that had already finished, so rows can be lost or duplicated.
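To see why a retry can lose rows, here is a toy model in plain Java (a sketch, not Spark code). It assumes a single map task feeding two reducers, starts round-robin at partition 0, and assumes reducer 0 already consumed the first attempt's output while reducer 1 reads the retried attempt's output. The class and method names are hypothetical.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
 * Toy model of the SPARK-23207 failure (illustrative only, not Spark code).
 * A map task emits the same rows in a different order on each attempt and
 * assigns them to reducers round-robin.
 */
public class RetryLossDemo {

    /** Round-robin: the i-th row in iteration order goes to partition i % numPartitions. */
    public static List<List<Integer>> roundRobin(List<Integer> rows, int numPartitions) {
        List<List<Integer>> parts = new ArrayList<>();
        for (int p = 0; p < numPartitions; p++) {
            parts.add(new ArrayList<>());
        }
        for (int i = 0; i < rows.size(); i++) {
            parts.get(i % numPartitions).add(rows.get(i));
        }
        return parts;
    }

    /** Distinct rows seen when reducer 0 reads attempt 1 and reducer 1 reads the retried attempt 2. */
    public static int distinctRowsSeen(List<Integer> attempt1, List<Integer> attempt2) {
        List<List<Integer>> first = roundRobin(attempt1, 2);
        List<List<Integer>> retry = roundRobin(attempt2, 2);
        Set<Integer> seen = new HashSet<>(first.get(0)); // consumed before the failure
        seen.addAll(retry.get(1));                        // consumed after the retry
        return seen.size();
    }

    public static void main(String[] args) {
        List<Integer> attempt1 = List.of(1, 2, 3, 4, 5, 6);
        List<Integer> attempt2 = List.of(2, 1, 4, 3, 6, 5); // same rows, different order
        System.out.println(distinctRowsSeen(attempt1, attempt2) + " of " + attempt1.size()); // 3 of 6
    }
}
```

With these row orders both reducers receive rows 1, 3 and 5, while rows 2, 4 and 6 are never read, which is the same kind of symptom as the distinct count falling short of 1000000 below.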

The following code returns 931532, instead of 1000000:
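```scala
// spark-shell snippet: `spark` is the shell's SparkSession.
import scala.sys.process._

import org.apache.spark.TaskContext
val res = spark.range(0, 1000 * 1000, 1).repartition(200).map { x =>
  x
}.repartition(200).map { x =>
  // On the first attempt of partitions 0 and 1, kill every JVM on the executor's host
  // (`pkill -f java`) so that the executor is lost and tasks are retried.
  if (TaskContext.get.attemptNumber == 0 && TaskContext.get.partitionId < 2) {
    throw new Exception("pkill -f java".!!)
  }
  x
}
res.distinct().count()
```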

In this PR, we propose the most straightforward fix: perform a local sort before partitioning.
Once the input row ordering is deterministic, the function from rows to partitions is fully
deterministic too.
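The idea behind the fix can be illustrated with a similar toy model: sort each map task's rows locally, then assign them round-robin. This is a hedged Java sketch; natural `Integer` ordering stands in for Spark's binary row comparator, round-robin always starts at partition 0 (a simplification), and all names are hypothetical.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/** Sketch: an optional local sort before round-robin assignment (illustrative, not Spark code). */
public class SortBeforeRoundRobin {

    public static List<List<Integer>> repartition(List<Integer> rows, int numPartitions, boolean sortFirst) {
        List<Integer> input = new ArrayList<>(rows);
        if (sortFirst) {
            Collections.sort(input); // the local sort fixes the iteration order
        }
        List<List<Integer>> parts = new ArrayList<>();
        for (int p = 0; p < numPartitions; p++) {
            parts.add(new ArrayList<>());
        }
        for (int i = 0; i < input.size(); i++) {
            parts.get(i % numPartitions).add(input.get(i)); // round-robin by position
        }
        return parts;
    }

    public static void main(String[] args) {
        List<Integer> attempt1 = List.of(5, 3, 1, 4, 2, 6);
        List<Integer> attempt2 = List.of(6, 2, 4, 1, 3, 5); // same rows, shuffled
        System.out.println(repartition(attempt1, 2, false).equals(repartition(attempt2, 2, false))); // false
        System.out.println(repartition(attempt1, 2, true).equals(repartition(attempt2, 2, true)));   // true
    }
}
```

Because identical rows are interchangeable, the sorted order is unique even when the input contains duplicates, which is the case the ExchangeSuite test also covers.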

The downside of this approach is that the extra local sort degrades the performance of repartition(),
so we add a new config named `spark.sql.execution.sortBeforeRepartition` to control whether this
patch is applied. It is enabled by default to be safe by default, but users may turn it off manually
to avoid the performance regression.

This patch also changes the output row ordering of repartition(), which makes several test cases
fail because they compare results directly; those tests are updated to compare results
independently of row order.
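The affected tests were adjusted to compare results without depending on row order, either by sorting both sides (Word2VecSuite) or by converting them to sets (ParquetIOSuite). A small Java sketch of the difference, with hypothetical names: sorting keeps duplicate rows significant, while set comparison ignores them.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;

/** Sketch: two order-insensitive ways to compare collected results. */
public class OrderInsensitiveCompare {

    /** Compare sorted copies: ignores order but keeps duplicates (multiset equality). */
    public static boolean sameMultiset(List<String> a, List<String> b) {
        List<String> x = new ArrayList<>(a);
        List<String> y = new ArrayList<>(b);
        Collections.sort(x);
        Collections.sort(y);
        return x.equals(y);
    }

    /** Compare as sets: ignores order and also how often each row occurs. */
    public static boolean sameSet(List<String> a, List<String> b) {
        return new HashSet<>(a).equals(new HashSet<>(b));
    }

    public static void main(String[] args) {
        List<String> expected = List.of("a", "b", "b");
        System.out.println(sameMultiset(expected, List.of("b", "a", "b"))); // true
        System.out.println(sameSet(expected, List.of("a", "b")));           // true, duplicates ignored
        System.out.println(sameMultiset(expected, List.of("a", "b")));      // false
    }
}
```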

## How was this patch tested?

Added a unit test in ExchangeSuite.

With this patch (and `spark.sql.execution.sortBeforeRepartition` set to true), the following
query returns 1000000:
```scala
import scala.sys.process._

import org.apache.spark.TaskContext

spark.conf.set("spark.sql.execution.sortBeforeRepartition", "true")

val res = spark.range(0, 1000 * 1000, 1).repartition(200).map { x =>
  x
}.repartition(200).map { x =>
  if (TaskContext.get.attemptNumber == 0 && TaskContext.get.partitionId < 2) {
    throw new Exception("pkill -f java".!!)
  }
  x
}
res.distinct().count()

// res7: Long = 1000000
```

Author: Xingbo Jiang <xingbo.jiang@databricks.com>

Closes #20393 from jiangxb1987/shuffle-repartition.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/94c67a76
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/94c67a76
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/94c67a76

Branch: refs/heads/master
Commit: 94c67a76ec1fda908a671a47a2a1fa63b3ab1b06
Parents: a8a3e9b
Author: Xingbo Jiang <xingbo.jiang@databricks.com>
Authored: Fri Jan 26 15:01:03 2018 -0800
Committer: Sameer Agarwal <sameerag@apache.org>
Committed: Fri Jan 26 15:01:03 2018 -0800

----------------------------------------------------------------------
 .../unsafe/sort/RecordComparator.java           |  4 +-
 .../unsafe/sort/UnsafeInMemorySorter.java       |  7 +-
 .../unsafe/sort/UnsafeSorterSpillMerger.java    |  4 +-
 .../main/scala/org/apache/spark/rdd/RDD.scala   |  2 +
 .../unsafe/sort/UnsafeExternalSorterSuite.java  |  4 +-
 .../unsafe/sort/UnsafeInMemorySorterSuite.java  |  8 ++-
 .../apache/spark/ml/feature/Word2VecSuite.scala |  3 +-
 .../sql/execution/RecordBinaryComparator.java   | 70 ++++++++++++++++++++
 .../sql/execution/UnsafeExternalRowSorter.java  | 44 ++++++++++--
 .../org/apache/spark/sql/internal/SQLConf.scala | 14 ++++
 .../sql/execution/UnsafeKVExternalSorter.java   |  8 ++-
 .../apache/spark/sql/execution/SortExec.scala   |  2 +-
 .../exchange/ShuffleExchangeExec.scala          | 52 ++++++++++++++-
 .../spark/sql/execution/ExchangeSuite.scala     | 26 +++++++-
 .../datasources/parquet/ParquetIOSuite.scala    |  6 +-
 .../datasources/text/WholeTextFileSuite.scala   |  2 +-
 .../execution/streaming/ForeachSinkSuite.scala  |  6 +-
 17 files changed, 233 insertions(+), 29 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/94c67a76/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java
index 09e4258..02b5de8 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java
@@ -32,6 +32,8 @@ public abstract class RecordComparator {
   public abstract int compare(
     Object leftBaseObject,
     long leftBaseOffset,
+    int leftBaseLength,
     Object rightBaseObject,
-    long rightBaseOffset);
+    long rightBaseOffset,
+    int rightBaseLength);
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/94c67a76/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index 951d076..b3c27d8 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -62,12 +62,13 @@ public final class UnsafeInMemorySorter {
       int uaoSize = UnsafeAlignedOffset.getUaoSize();
       if (prefixComparisonResult == 0) {
         final Object baseObject1 = memoryManager.getPage(r1.recordPointer);
-        // skip length
         final long baseOffset1 = memoryManager.getOffsetInPage(r1.recordPointer) + uaoSize;
+        final int baseLength1 = UnsafeAlignedOffset.getSize(baseObject1, baseOffset1 - uaoSize);
         final Object baseObject2 = memoryManager.getPage(r2.recordPointer);
-        // skip length
         final long baseOffset2 = memoryManager.getOffsetInPage(r2.recordPointer) + uaoSize;
-        return recordComparator.compare(baseObject1, baseOffset1, baseObject2, baseOffset2);
+        final int baseLength2 = UnsafeAlignedOffset.getSize(baseObject2, baseOffset2 - uaoSize);
+        return recordComparator.compare(baseObject1, baseOffset1, baseLength1, baseObject2,
+          baseOffset2, baseLength2);
       } else {
         return prefixComparisonResult;
       }
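The sorters now pass each record's length to the comparator. Judging from the diff, a record is stored as a length word followed by its bytes, the record pointer addresses the length word, and the length is read back from `baseOffset - uaoSize`. The Java sketch below models that layout with a `ByteBuffer`; the fixed 4-byte length word and all names are assumptions for illustration, not Spark's memory manager.

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

/** Sketch of a page of length-prefixed records (illustrative layout, not Spark's). */
public class LengthPrefixedPage {
    static final int LEN_SIZE = 4; // stands in for UnsafeAlignedOffset.getUaoSize()

    /** Appends a record and returns its pointer, i.e. the offset of its length word. */
    public static int append(ByteBuffer page, byte[] record) {
        int pointer = page.position();
        page.putInt(record.length);
        page.put(record);
        return pointer;
    }

    /** Record data starts right after the length word (pointer + uaoSize in the diff). */
    public static int dataOffset(int pointer) {
        return pointer + LEN_SIZE;
    }

    /** The length is read from dataOffset - LEN_SIZE, mirroring the new baseLength lines. */
    public static int length(ByteBuffer page, int dataOffset) {
        return page.getInt(dataOffset - LEN_SIZE);
    }

    /** Writes all records into a fresh page, then reads each length back via its data offset. */
    public static int[] roundTripLengths(String... records) {
        ByteBuffer page = ByteBuffer.allocate(256);
        int[] pointers = new int[records.length];
        for (int i = 0; i < records.length; i++) {
            pointers[i] = append(page, records[i].getBytes(StandardCharsets.UTF_8));
        }
        int[] lengths = new int[records.length];
        for (int i = 0; i < records.length; i++) {
            lengths[i] = length(page, dataOffset(pointers[i]));
        }
        return lengths;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(roundTripLengths("spark", "sql"))); // [5, 3]
    }
}
```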

http://git-wip-us.apache.org/repos/asf/spark/blob/94c67a76/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
index cf4dfde..ff0dcc2 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
@@ -35,8 +35,8 @@ final class UnsafeSorterSpillMerger {
         prefixComparator.compare(left.getKeyPrefix(), right.getKeyPrefix());
       if (prefixComparisonResult == 0) {
         return recordComparator.compare(
-          left.getBaseObject(), left.getBaseOffset(),
-          right.getBaseObject(), right.getBaseOffset());
+          left.getBaseObject(), left.getBaseOffset(), left.getRecordLength(),
+          right.getBaseObject(), right.getBaseOffset(), right.getRecordLength());
       } else {
         return prefixComparisonResult;
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/94c67a76/core/src/main/scala/org/apache/spark/rdd/RDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 7859781..0574abd 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -414,6 +414,8 @@ abstract class RDD[T: ClassTag](
    *
    * If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
    * which can avoid performing a shuffle.
+   *
+   * TODO Fix the Shuffle+Repartition data loss issue described in SPARK-23207.
    */
   def repartition(numPartitions: Int)(implicit ord: Ordering[T] = null): RDD[T] = withScope {
     coalesce(numPartitions, shuffle = true)

http://git-wip-us.apache.org/repos/asf/spark/blob/94c67a76/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index af4975c..411cd5c 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -72,8 +72,10 @@ public class UnsafeExternalSorterSuite {
     public int compare(
       Object leftBaseObject,
       long leftBaseOffset,
+      int leftBaseLength,
       Object rightBaseObject,
-      long rightBaseOffset) {
+      long rightBaseOffset,
+      int rightBaseLength) {
       return 0;
     }
   };

http://git-wip-us.apache.org/repos/asf/spark/blob/94c67a76/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
index 594f07d..c145532 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
@@ -98,8 +98,10 @@ public class UnsafeInMemorySorterSuite {
       public int compare(
         Object leftBaseObject,
         long leftBaseOffset,
+        int leftBaseLength,
         Object rightBaseObject,
-        long rightBaseOffset) {
+        long rightBaseOffset,
+        int rightBaseLength) {
         return 0;
       }
     };
@@ -164,8 +166,10 @@ public class UnsafeInMemorySorterSuite {
       public int compare(
               Object leftBaseObject,
               long leftBaseOffset,
+              int leftBaseLength,
               Object rightBaseObject,
-              long rightBaseOffset) {
+              long rightBaseOffset,
+              int rightBaseLength) {
         return 0;
       }
     };

http://git-wip-us.apache.org/repos/asf/spark/blob/94c67a76/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index 6183606..10682ba 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -222,7 +222,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
     val oldModel = new OldWord2VecModel(word2VecMap)
     val instance = new Word2VecModel("myWord2VecModel", oldModel)
     val newInstance = testDefaultReadWrite(instance)
-    assert(newInstance.getVectors.collect() === instance.getVectors.collect())
+    assert(newInstance.getVectors.collect().sortBy(_.getString(0)) ===
+      instance.getVectors.collect().sortBy(_.getString(0)))
   }
 
   test("Word2Vec works with input that is non-nullable (NGram)") {

http://git-wip-us.apache.org/repos/asf/spark/blob/94c67a76/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java
new file mode 100644
index 0000000..bb77b5b
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution;
+
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.util.collection.unsafe.sort.RecordComparator;
+
+public final class RecordBinaryComparator extends RecordComparator {
+
+  // TODO(jiangxb) Add test suite for this.
+  @Override
+  public int compare(
+      Object leftObj, long leftOff, int leftLen, Object rightObj, long rightOff, int rightLen) {
+    int i = 0;
+    int res = 0;
+
+    // If the arrays have different length, the longer one is larger.
+    if (leftLen != rightLen) {
+      return leftLen - rightLen;
+    }
+
+    // The following logic uses `leftLen` as the length for both `leftObj` and `rightObj`, since
+    // we have guaranteed `leftLen` == `rightLen`.
+
+    // check if stars align and we can get both offsets to be aligned
+    if ((leftOff % 8) == (rightOff % 8)) {
+      while ((leftOff + i) % 8 != 0 && i < leftLen) {
+        res = (Platform.getByte(leftObj, leftOff + i) & 0xff) -
+                (Platform.getByte(rightObj, rightOff + i) & 0xff);
+        if (res != 0) return res;
+        i += 1;
+      }
+    }
+    // for architectures that support unaligned accesses, chew it up 8 bytes at a time
+    if (Platform.unaligned() || (((leftOff + i) % 8 == 0) && ((rightOff + i) % 8 == 0))) {
+      while (i <= leftLen - 8) {
+        res = (int) ((Platform.getLong(leftObj, leftOff + i) -
+                Platform.getLong(rightObj, rightOff + i)) % Integer.MAX_VALUE);
+        if (res != 0) return res;
+        i += 8;
+      }
+    }
+    // this will finish off the unaligned comparisons, or do the entire aligned comparison
+    // whichever is needed.
+    while (i < leftLen) {
+      res = (Platform.getByte(leftObj, leftOff + i) & 0xff) -
+              (Platform.getByte(rightObj, rightOff + i) & 0xff);
+      if (res != 0) return res;
+      i += 1;
+    }
+
+    // The two arrays are equal.
+    return 0;
+  }
+}
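`RecordBinaryComparator` orders records first by length and then by their raw bytes, which is all the fix needs: any deterministic total order over row bytes makes the local sort repeatable. Below is a simplified Java analogue over `byte[]` (hypothetical class name, not Spark's API). It compares byte by byte with `Integer.compare`; comparing 64-bit words by subtraction, as the 8-byte loop above does, can overflow for large differences, so a comparison method is the safer pattern for a sketch.

```java
import java.util.Arrays;
import java.util.Comparator;

/** Length-first, then unsigned byte-wise ordering of records (illustrative sketch). */
public class BinaryRecordOrder implements Comparator<byte[]> {

    @Override
    public int compare(byte[] left, byte[] right) {
        if (left.length != right.length) {
            return Integer.compare(left.length, right.length); // the longer record is larger
        }
        for (int i = 0; i < left.length; i++) {
            int cmp = Integer.compare(left[i] & 0xff, right[i] & 0xff); // unsigned byte order
            if (cmp != 0) {
                return cmp;
            }
        }
        return 0; // identical bytes
    }

    public static void main(String[] args) {
        byte[][] rows = { {1, 2, 3}, {(byte) 0xff}, {1, 2}, {1, 2, 0} };
        Arrays.sort(rows, new BinaryRecordOrder());
        System.out.println(Arrays.deepToString(rows)); // [[-1], [1, 2], [1, 2, 0], [1, 2, 3]]
    }
}
```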

http://git-wip-us.apache.org/repos/asf/spark/blob/94c67a76/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 6b002f0..78647b5 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -18,7 +18,9 @@
 package org.apache.spark.sql.execution;
 
 import java.io.IOException;
+import java.util.function.Supplier;
 
+import org.apache.spark.sql.catalyst.util.TypeUtils;
 import scala.collection.AbstractIterator;
 import scala.collection.Iterator;
 import scala.math.Ordering;
@@ -56,26 +58,50 @@ public final class UnsafeExternalRowSorter {
 
     public static class Prefix {
       /** Key prefix value, or the null prefix value if isNull = true. **/
-      long value;
+      public long value;
 
       /** Whether the key is null. */
-      boolean isNull;
+      public boolean isNull;
     }
 
     /**
      * Computes prefix for the given row. For efficiency, the returned object may be reused in
      * further calls to a given PrefixComputer.
      */
-    abstract Prefix computePrefix(InternalRow row);
+    public abstract Prefix computePrefix(InternalRow row);
   }
 
-  public UnsafeExternalRowSorter(
+  public static UnsafeExternalRowSorter createWithRecordComparator(
+      StructType schema,
+      Supplier<RecordComparator> recordComparatorSupplier,
+      PrefixComparator prefixComparator,
+      PrefixComputer prefixComputer,
+      long pageSizeBytes,
+      boolean canUseRadixSort) throws IOException {
+    return new UnsafeExternalRowSorter(schema, recordComparatorSupplier, prefixComparator,
+      prefixComputer, pageSizeBytes, canUseRadixSort);
+  }
+
+  public static UnsafeExternalRowSorter create(
       StructType schema,
       Ordering<InternalRow> ordering,
       PrefixComparator prefixComparator,
       PrefixComputer prefixComputer,
       long pageSizeBytes,
       boolean canUseRadixSort) throws IOException {
+    Supplier<RecordComparator> recordComparatorSupplier =
+      () -> new RowComparator(ordering, schema.length());
+    return new UnsafeExternalRowSorter(schema, recordComparatorSupplier, prefixComparator,
+      prefixComputer, pageSizeBytes, canUseRadixSort);
+  }
+
+  private UnsafeExternalRowSorter(
+      StructType schema,
+      Supplier<RecordComparator> recordComparatorSupplier,
+      PrefixComparator prefixComparator,
+      PrefixComputer prefixComputer,
+      long pageSizeBytes,
+      boolean canUseRadixSort) throws IOException {
     this.schema = schema;
     this.prefixComputer = prefixComputer;
     final SparkEnv sparkEnv = SparkEnv.get();
@@ -85,7 +111,7 @@ public final class UnsafeExternalRowSorter {
       sparkEnv.blockManager(),
       sparkEnv.serializerManager(),
       taskContext,
-      () -> new RowComparator(ordering, schema.length()),
+      recordComparatorSupplier,
       prefixComparator,
       sparkEnv.conf().getInt("spark.shuffle.sort.initialBufferSize",
                              DEFAULT_INITIAL_SORT_BUFFER_SIZE),
@@ -206,7 +232,13 @@ public final class UnsafeExternalRowSorter {
     }
 
     @Override
-    public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) {
+    public int compare(
+        Object baseObj1,
+        long baseOff1,
+        int baseLen1,
+        Object baseObj2,
+        long baseOff2,
+        int baseLen2) {
       // Note that since ordering doesn't need the total length of the record, we just pass 0
       // into the row.
       row1.pointTo(baseObj1, baseOff1, 0);

http://git-wip-us.apache.org/repos/asf/spark/blob/94c67a76/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index b0d18b6..76b9d6f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1145,6 +1145,18 @@ object SQLConf {
       .checkValues(PartitionOverwriteMode.values.map(_.toString))
       .createWithDefault(PartitionOverwriteMode.STATIC.toString)
 
+  val SORT_BEFORE_REPARTITION =
+    buildConf("spark.sql.execution.sortBeforeRepartition")
+      .internal()
+      .doc("When perform a repartition following a shuffle, the output row ordering would be " +
+        "nondeterministic. If some downstream stages fail and some tasks of the repartition " +
+        "stage retry, these tasks may generate different data, and that can lead to correctness " +
+        "issues. Turn on this config to insert a local sort before actually doing repartition " +
+        "to generate consistent repartition results. The performance of repartition() may go " +
+        "down since we insert extra local sort before it.")
+      .booleanConf
+      .createWithDefault(true)
+
   object Deprecated {
     val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
   }
@@ -1300,6 +1312,8 @@ class SQLConf extends Serializable with Logging {
 
   def stringRedationPattern: Option[Regex] = SQL_STRING_REDACTION_PATTERN.readFrom(reader)
 
+  def sortBeforeRepartition: Boolean = getConf(SORT_BEFORE_REPARTITION)
+
   /**
    * Returns the [[Resolver]] for the current configuration, which can be used to determine if two
    * identifiers are equal.

http://git-wip-us.apache.org/repos/asf/spark/blob/94c67a76/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index eb2fe82..b0b5383 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -241,7 +241,13 @@ public final class UnsafeKVExternalSorter {
     }
 
     @Override
-    public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) {
+    public int compare(
+        Object baseObj1,
+        long baseOff1,
+        int baseLen1,
+        Object baseObj2,
+        long baseOff2,
+        int baseLen2) {
       // Note that since ordering doesn't need the total length of the record, we just pass 0
       // into the row.
       row1.pointTo(baseObj1, baseOff1 + 4, 0);

http://git-wip-us.apache.org/repos/asf/spark/blob/94c67a76/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
index ef1bb1c..ac1c34d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
@@ -84,7 +84,7 @@ case class SortExec(
     }
 
     val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
-    val sorter = new UnsafeExternalRowSorter(
+    val sorter = UnsafeExternalRowSorter.create(
       schema, ordering, prefixComparator, prefixComputer, pageSize, canUseRadixSort)
 
     if (testSpillFrequency > 0) {

http://git-wip-us.apache.org/repos/asf/spark/blob/94c67a76/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index 5a1e217..76c1fa6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution.exchange
 
 import java.util.Random
+import java.util.function.Supplier
 
 import org.apache.spark._
 import org.apache.spark.rdd.RDD
@@ -25,13 +26,15 @@ import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.sort.SortShuffleManager
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
-import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.MutablePair
+import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator}
 
 /**
  * Performs a shuffle that will result in the desired `newPartitioning`.
@@ -247,14 +250,57 @@ object ShuffleExchangeExec {
       case RangePartitioning(_, _) | SinglePartition => identity
       case _ => sys.error(s"Exchange not implemented for $newPartitioning")
     }
+
     val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = {
-      if (needToCopyObjectsBeforeShuffle(part, serializer)) {
+      // [SPARK-23207] Have to make sure the generated RoundRobinPartitioning is deterministic,
+      // otherwise a retry task may output different rows and thus lead to data loss.
+      //
+      // Currently we following the most straight-forward way that perform a local sort before
+      // partitioning.
+      val newRdd = if (SQLConf.get.sortBeforeRepartition &&
+          newPartitioning.isInstanceOf[RoundRobinPartitioning]) {
         rdd.mapPartitionsInternal { iter =>
+          val recordComparatorSupplier = new Supplier[RecordComparator] {
+            override def get: RecordComparator = new RecordBinaryComparator()
+          }
+          // The comparator for comparing row hashcode, which should always be Integer.
+          val prefixComparator = PrefixComparators.LONG
+          val canUseRadixSort = SparkEnv.get.conf.get(SQLConf.RADIX_SORT_ENABLED)
+          // The prefix computer generates row hashcode as the prefix, so we may decrease the
+          // probability that the prefixes are equal when input rows choose column values from a
+          // limited range.
+          val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
+            private val result = new UnsafeExternalRowSorter.PrefixComputer.Prefix
+            override def computePrefix(row: InternalRow):
+            UnsafeExternalRowSorter.PrefixComputer.Prefix = {
+              // The hashcode generated from the binary form of a [[UnsafeRow]] should not be null.
+              result.isNull = false
+              result.value = row.hashCode()
+              result
+            }
+          }
+          val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
+
+          val sorter = UnsafeExternalRowSorter.createWithRecordComparator(
+            StructType.fromAttributes(outputAttributes),
+            recordComparatorSupplier,
+            prefixComparator,
+            prefixComputer,
+            pageSize,
+            canUseRadixSort)
+          sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]])
+        }
+      } else {
+        rdd
+      }
+
+      if (needToCopyObjectsBeforeShuffle(part, serializer)) {
+        newRdd.mapPartitionsInternal { iter =>
           val getPartitionKey = getPartitionKeyExtractor()
           iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) }
         }
       } else {
-        rdd.mapPartitionsInternal { iter =>
+        newRdd.mapPartitionsInternal { iter =>
           val getPartitionKey = getPartitionKeyExtractor()
           val mutablePair = new MutablePair[Int, InternalRow]()
           iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) }
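Before round-robin partitioning, the patch sorts each map partition with a hash-code prefix (`PrefixComparators.LONG` over `row.hashCode()`) and falls back to `RecordBinaryComparator` when prefixes tie. The Java sketch below models that two-level order on `byte[]` rows; `Arrays.hashCode` stands in for `UnsafeRow.hashCode()`, and the names are hypothetical. The prefix carries no meaning beyond being cheap and deterministic.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

/** Sketch: hash prefix first, full binary comparison only on prefix ties. */
public class HashPrefixSort {

    /** Length-first, then unsigned byte-wise comparison (models RecordBinaryComparator). */
    static int compareBinary(byte[] a, byte[] b) {
        if (a.length != b.length) {
            return Integer.compare(a.length, b.length);
        }
        for (int i = 0; i < a.length; i++) {
            int cmp = Integer.compare(a[i] & 0xff, b[i] & 0xff);
            if (cmp != 0) {
                return cmp;
            }
        }
        return 0;
    }

    /** Compare the hash-code prefix as a long; break ties with the binary comparison. */
    public static final Comparator<byte[]> ORDER =
        Comparator.<byte[]>comparingLong(r -> Arrays.hashCode(r))
            .thenComparing(HashPrefixSort::compareBinary);

    public static List<byte[]> sortedCopy(List<byte[]> rows) {
        List<byte[]> copy = new ArrayList<>(rows);
        copy.sort(ORDER);
        return copy;
    }

    /** True if both attempts produce the same row contents in the same sorted order. */
    public static boolean sameOrder(List<byte[]> a, List<byte[]> b) {
        return Arrays.deepEquals(sortedCopy(a).toArray(), sortedCopy(b).toArray());
    }

    public static void main(String[] args) {
        List<byte[]> attempt1 = List.of(new byte[] {3}, new byte[] {1, 2}, new byte[] {7}, new byte[] {3});
        List<byte[]> attempt2 = List.of(new byte[] {3}, new byte[] {7}, new byte[] {3}, new byte[] {1, 2});
        System.out.println(sameOrder(attempt1, attempt2)); // true
    }
}
```

Two attempts that emit the same rows in different orders, including a duplicated row, end up in the same sorted order, so the round-robin assignment that follows is identical.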

http://git-wip-us.apache.org/repos/asf/spark/blob/94c67a76/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
index aac8d56..697d7e6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
@@ -17,11 +17,14 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.sql.Row
+import scala.util.Random
+
+import org.apache.spark.sql.{Dataset, Row}
 import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, IdentityBroadcastMode, SinglePartition}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 
 class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
@@ -101,4 +104,25 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
     assert(exchange4.sameResult(exchange5))
     assert(exchange5 sameResult exchange4)
   }
+
+  test("SPARK-23207: Make repartition() generate consistent output") {
+    def assertConsistency(ds: Dataset[java.lang.Long]): Unit = {
+      ds.persist()
+
+      val exchange = ds.mapPartitions { iter =>
+        Random.shuffle(iter)
+      }.repartition(111)
+      val exchange2 = ds.repartition(111)
+
+      assert(exchange.rdd.collectPartitions() === exchange2.rdd.collectPartitions())
+    }
+
+    withSQLConf(SQLConf.SORT_BEFORE_REPARTITION.key -> "true") {
+      // repartition() should generate consistent output.
+      assertConsistency(spark.range(10000))
+
+      // case when input contains duplicated rows.
+      assertConsistency(spark.range(10000).map(i => Random.nextInt(1000).toLong))
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/94c67a76/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index 44a8b25..f3ece5b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -662,7 +662,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
             val v = (row.getInt(0), row.getString(1))
             result += v
           }
-          assert(data == result)
+          assert(data.toSet == result.toSet)
         } finally {
           reader.close()
         }
@@ -678,7 +678,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
             val row = reader.getCurrentValue.asInstanceOf[InternalRow]
             result += row.getString(0)
           }
-          assert(data.map(_._2) == result)
+          assert(data.map(_._2).toSet == result.toSet)
         } finally {
           reader.close()
         }
@@ -695,7 +695,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
             val v = (row.getString(0), row.getInt(1))
             result += v
           }
-          assert(data.map { x => (x._2, x._1) } == result)
+          assert(data.map { x => (x._2, x._1) }.toSet == result.toSet)
         } finally {
           reader.close()
         }
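The ParquetIOSuite assertions now compare via `toSet`, which makes them insensitive to the row order that the sorted repartition produces. Note that `toSet` also collapses duplicates. Where multiplicity matters, a stricter order-insensitive check compares element counts; the helper `sameMultiset` below is a hypothetical sketch, not part of the commit or of ScalaTest.

```scala
object MultisetCompare {
  // Hypothetical helper: true when both collections hold the same elements
  // with the same number of occurrences, ignoring order.
  def sameMultiset[A](a: Iterable[A], b: Iterable[A]): Boolean = {
    def counts(xs: Iterable[A]): Map[A, Int] =
      xs.groupBy(identity).map { case (k, v) => k -> v.size }
    counts(a) == counts(b)
  }
}
```

For example, `Seq(1, 2, 2).toSet == Seq(1, 2).toSet` is true, whereas `sameMultiset(Seq(1, 2, 2), Seq(1, 2))` is false.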

http://git-wip-us.apache.org/repos/asf/spark/blob/94c67a76/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala
index 8bd736b..fff0f82 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala
@@ -95,7 +95,7 @@ class WholeTextFileSuite extends QueryTest with SharedSQLContext {
       df1.write.option("compression", "gzip").mode("overwrite").text(path)
       // On reading through wholetext mode, one file will be read as a single row, i.e. not
       // delimited by "next line" character.
-      val expected = Row(Range(0, 1000).mkString("", "\n", "\n"))
+      val expected = Row(df1.collect().map(_.getString(0)).mkString("", "\n", "\n"))
       Seq(10, 100, 1000).foreach { bytes =>
         withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> bytes.toString) {
           val df2 = spark.read.option("wholetext", "true").format("text").load(path)

http://git-wip-us.apache.org/repos/asf/spark/blob/94c67a76/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
index 9137d65..1248c67 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
@@ -52,13 +52,13 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf
 
       var expectedEventsForPartition0 = Seq(
         ForeachSinkSuite.Open(partition = 0, version = 0),
-        ForeachSinkSuite.Process(value = 1),
+        ForeachSinkSuite.Process(value = 2),
         ForeachSinkSuite.Process(value = 3),
         ForeachSinkSuite.Close(None)
       )
       var expectedEventsForPartition1 = Seq(
         ForeachSinkSuite.Open(partition = 1, version = 0),
-        ForeachSinkSuite.Process(value = 2),
+        ForeachSinkSuite.Process(value = 1),
         ForeachSinkSuite.Process(value = 4),
         ForeachSinkSuite.Close(None)
       )
@@ -162,7 +162,7 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf
       val allEvents = ForeachSinkSuite.allEvents()
       assert(allEvents.size === 1)
       assert(allEvents(0)(0) === ForeachSinkSuite.Open(partition = 0, version = 0))
-      assert(allEvents(0)(1) === ForeachSinkSuite.Process(value = 1))
+      assert(allEvents(0)(1) === ForeachSinkSuite.Process(value = 2))
 
       // `close` should be called with the error
       val errorEvent = allEvents(0)(2).asInstanceOf[ForeachSinkSuite.Close]

