spark-commits mailing list archives

From andrewo...@apache.org
Subject spark git commit: [SPARK-8029] robust shuffle writer (for 1.5 branch)
Date Fri, 13 Nov 2015 20:59:31 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.5 3676d4c4d -> 330961bbf


[SPARK-8029] robust shuffle writer (for 1.5 branch)

Currently, all the shuffle writers write to the target path directly, so the file can be corrupted
by another attempt of the same partition on the same executor. They should instead write to a
temporary file and then rename it to the target path, as we do in the output committer. To make
the rename atomic, the temporary file must be created in the same local directory (FileSystem).
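As an illustration only (not part of the patch; the object and method names below are hypothetical),
a minimal Scala sketch of the write-to-temp-then-rename pattern the change adopts:

import java.io.{File, IOException}
import java.util.UUID

object ShuffleCommitSketch {
  // Hypothetical helper illustrating the pattern: write to a temporary file that
  // lives next to the target (same directory, hence same FileSystem), then rename
  // it into place under a lock so concurrent attempts cannot clobber each other.
  def commitAtomically(target: File)(write: File => Unit): Unit = {
    val tmp = new File(target.getAbsolutePath + "." + UUID.randomUUID())
    write(tmp)
    synchronized {
      if (target.exists()) {
        // Another attempt of the same partition already committed; keep its output.
        tmp.delete()
      } else if (!tmp.renameTo(target)) {
        throw new IOException(s"fail to rename file $tmp to $target")
      }
    }
  }
}

In the actual change below, Utils.tempFileWith creates the temporary file next to the target, and
IndexShuffleBlockResolver.writeIndexFileAndCommit performs the synchronized check-and-rename.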

This PR is based on #9214 , thanks to squito

Author: Davies Liu <davies@databricks.com>

Closes #9686 from davies/writer_1.5 and squashes the following commits:

e95fcf5 [Davies Liu] fix test
a6d569e [Davies Liu] fix consolidate
7e83298 [Davies Liu] robust shuffle writer


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/330961bb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/330961bb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/330961bb

Branch: refs/heads/branch-1.5
Commit: 330961bbfcd36a6634549c663209a9f6d0e115ca
Parents: 3676d4c
Author: Davies Liu <davies@databricks.com>
Authored: Fri Nov 13 12:59:28 2015 -0800
Committer: Andrew Or <andrew@databricks.com>
Committed: Fri Nov 13 12:59:28 2015 -0800

----------------------------------------------------------------------
 .../shuffle/unsafe/UnsafeShuffleWriter.java     |  10 +-
 .../shuffle/FileShuffleBlockResolver.scala      |  19 +---
 .../shuffle/IndexShuffleBlockResolver.scala     | 103 ++++++++++++++++--
 .../spark/shuffle/hash/HashShuffleWriter.scala  |  27 +++++
 .../spark/shuffle/sort/SortShuffleWriter.scala  |  11 +-
 .../spark/storage/DiskBlockObjectWriter.scala   |   2 +-
 .../scala/org/apache/spark/util/Utils.scala     |  11 +-
 .../unsafe/UnsafeShuffleWriterSuite.java        |   6 +-
 .../scala/org/apache/spark/ShuffleSuite.scala   | 109 ++++++++++++++++++-
 .../shuffle/hash/HashShuffleManagerSuite.scala  |   1 +
 10 files changed, 262 insertions(+), 37 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/330961bb/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
index 2389c28..c79bb91 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
@@ -55,6 +55,7 @@ import org.apache.spark.storage.BlockManager;
 import org.apache.spark.storage.TimeTrackingOutputStream;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;
 
 @Private
 public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
@@ -217,8 +218,10 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
     final SpillInfo[] spills = sorter.closeAndGetSpills();
     sorter = null;
     final long[] partitionLengths;
+    final File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
+    final File tmp = Utils.tempFileWith(output);
     try {
-      partitionLengths = mergeSpills(spills);
+      partitionLengths = mergeSpills(spills, tmp);
     } finally {
       for (SpillInfo spill : spills) {
         if (spill.file.exists() && ! spill.file.delete()) {
@@ -226,7 +229,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
         }
       }
     }
-    shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
+    shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
     mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
   }
 
@@ -259,8 +262,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
    *
    * @return the partition lengths in the merged file.
    */
-  private long[] mergeSpills(SpillInfo[] spills) throws IOException {
-    final File outputFile = shuffleBlockResolver.getDataFile(shuffleId, mapId);
+  private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException {
     final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true);
     final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf);
     final boolean fastMergeEnabled =

http://git-wip-us.apache.org/repos/asf/spark/blob/330961bb/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
index f6a96d8..40b1138 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
@@ -23,15 +23,15 @@ import java.util.concurrent.atomic.AtomicInteger
 
 import scala.collection.JavaConversions._
 
-import org.apache.spark.{Logging, SparkConf, SparkEnv}
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
 import org.apache.spark.network.netty.SparkTransportConf
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.FileShuffleBlockResolver.ShuffleFileGroup
 import org.apache.spark.storage._
-import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
 import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector}
+import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils}
+import org.apache.spark.{Logging, SparkConf, SparkEnv}
 
 /** A group of writers for a ShuffleMapTask, one writer per reducer. */
 private[spark] trait ShuffleWriterGroup {
@@ -73,7 +73,7 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf)
 
   // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file.
   // TODO: Remove this once the shuffle file consolidation feature is stable.
-  private val consolidateShuffleFiles =
+  val consolidateShuffleFiles =
     conf.getBoolean("spark.shuffle.consolidateFiles", false)
 
   // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
@@ -124,17 +124,8 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf)
         Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId =>
           val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
           val blockFile = blockManager.diskBlockManager.getFile(blockId)
-          // Because of previous failures, the shuffle file may already exist on this machine.
-          // If so, remove it.
-          if (blockFile.exists) {
-            if (blockFile.delete()) {
-              logInfo(s"Removed existing shuffle file $blockFile")
-            } else {
-              logWarning(s"Failed to remove existing shuffle file $blockFile")
-            }
-          }
-          blockManager.getDiskWriter(blockId, blockFile, serializerInstance, bufferSize,
-            writeMetrics)
+          val tmp = Utils.tempFileWith(blockFile)
+          blockManager.getDiskWriter(blockId, tmp, serializerInstance, bufferSize, writeMetrics)
         }
       }
       // Creating the file to write to and creating a disk writer both involve interacting with

http://git-wip-us.apache.org/repos/asf/spark/blob/330961bb/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
index d0163d3..7f71832 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
@@ -21,13 +21,12 @@ import java.io._
 
 import com.google.common.io.ByteStreams
 
-import org.apache.spark.{SparkConf, SparkEnv}
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
 import org.apache.spark.network.netty.SparkTransportConf
+import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID
 import org.apache.spark.storage._
 import org.apache.spark.util.Utils
-
-import IndexShuffleBlockResolver.NOOP_REDUCE_ID
+import org.apache.spark.{Logging, SparkConf, SparkEnv}
 
 /**
  * Create and maintain the shuffle blocks' mapping between logic block and physical file location.
@@ -40,9 +39,13 @@ import IndexShuffleBlockResolver.NOOP_REDUCE_ID
  */
 // Note: Changes to the format in this file should be kept in sync with
 // org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getSortBasedShuffleBlockData().
-private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleBlockResolver {
+private[spark] class IndexShuffleBlockResolver(
+    conf: SparkConf,
+    _blockManager: BlockManager = null)
+  extends ShuffleBlockResolver
+  with Logging {
 
-  private lazy val blockManager = SparkEnv.get.blockManager
+  private lazy val blockManager = Option(_blockManager).getOrElse(SparkEnv.get.blockManager)
 
   private val transportConf = SparkTransportConf.fromSparkConf(conf)
 
@@ -70,13 +73,68 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB
   }
 
   /**
+   * Check whether the given index and data files match each other.
+   * If so, return the partition lengths in the data file. Otherwise return null.
+   */
+  private def checkIndexAndDataFile(index: File, data: File, blocks: Int): Array[Long] = {
+    // the index file should have `block + 1` longs as offset.
+    if (index.length() != (blocks + 1) * 8) {
+      return null
+    }
+    val lengths = new Array[Long](blocks)
+    // Read the lengths of blocks
+    val in = try {
+      new DataInputStream(new BufferedInputStream(new FileInputStream(index)))
+    } catch {
+      case e: IOException =>
+        return null
+    }
+    try {
+      // Convert the offsets into lengths of each block
+      var offset = in.readLong()
+      if (offset != 0L) {
+        return null
+      }
+      var i = 0
+      while (i < blocks) {
+        val off = in.readLong()
+        lengths(i) = off - offset
+        offset = off
+        i += 1
+      }
+    } catch {
+      case e: IOException =>
+        return null
+    } finally {
+      in.close()
+    }
+
+    // the size of data file should match with index file
+    if (data.length() == lengths.sum) {
+      lengths
+    } else {
+      null
+    }
+  }
+
+  /**
    * Write an index file with the offsets of each block, plus a final offset at the end for the
    * end of the output file. This will be used by getBlockData to figure out where each block
    * begins and ends.
+   *
+   * It will commit the data and index file as an atomic operation, use the existing ones, or
+   * replace them with new ones.
+   *
+   * Note: the `lengths` will be updated to match the existing index file if use the existing ones.
    * */
-  def writeIndexFile(shuffleId: Int, mapId: Int, lengths: Array[Long]): Unit = {
+  def writeIndexFileAndCommit(
+      shuffleId: Int,
+      mapId: Int,
+      lengths: Array[Long],
+      dataTmp: File): Unit = {
     val indexFile = getIndexFile(shuffleId, mapId)
-    val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile)))
+    val indexTmp = Utils.tempFileWith(indexFile)
+    val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp)))
     Utils.tryWithSafeFinally {
       // We take in lengths of each block, need to convert it to offsets.
       var offset = 0L
@@ -88,6 +146,37 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB
     } {
       out.close()
     }
+
+    val dataFile = getDataFile(shuffleId, mapId)
+    // There is only one IndexShuffleBlockResolver per executor, this synchronization make sure
+    // the following check and rename are atomic.
+    synchronized {
+      val existingLengths = checkIndexAndDataFile(indexFile, dataFile, lengths.length)
+      if (existingLengths != null) {
+        // Another attempt for the same task has already written our map outputs successfully,
+        // so just use the existing partition lengths and delete our temporary map outputs.
+        System.arraycopy(existingLengths, 0, lengths, 0, lengths.length)
+        if (dataTmp != null && dataTmp.exists()) {
+          dataTmp.delete()
+        }
+        indexTmp.delete()
+      } else {
+        // This is the first successful attempt in writing the map outputs for this task,
+        // so override any existing index and data files with the ones we wrote.
+        if (indexFile.exists()) {
+          indexFile.delete()
+        }
+        if (dataFile.exists()) {
+          dataFile.delete()
+        }
+        if (!indexTmp.renameTo(indexFile)) {
+          throw new IOException("fail to rename file " + indexTmp + " to " + indexFile)
+        }
+        if (dataTmp != null && dataTmp.exists() && !dataTmp.renameTo(dataFile)) {
+          throw new IOException("fail to rename file " + dataTmp + " to " + dataFile)
+        }
+      }
+    }
   }
 
   override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {

http://git-wip-us.apache.org/repos/asf/spark/blob/330961bb/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
index 41df70c..e044516 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.shuffle.hash
 
+import java.io.IOException
+
 import org.apache.spark._
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.scheduler.MapStatus
@@ -106,6 +108,31 @@ private[spark] class HashShuffleWriter[K, V](
       writer.commitAndClose()
       writer.fileSegment().length
     }
+    if (!shuffleBlockResolver.consolidateShuffleFiles) {
+      // rename all shuffle files to final paths
+      // Note: there is only one ShuffleBlockResolver in executor
+      shuffleBlockResolver.synchronized {
+        shuffle.writers.zipWithIndex.foreach { case (writer, i) =>
+          val output = blockManager.diskBlockManager.getFile(writer.blockId)
+          if (sizes(i) > 0) {
+            if (output.exists()) {
+              // Use length of existing file and delete our own temporary one
+              sizes(i) = output.length()
+              writer.file.delete()
+            } else {
+              // Commit by renaming our temporary file to something the fetcher expects
+              if (!writer.file.renameTo(output)) {
+                throw new IOException(s"fail to rename ${writer.file} to $output")
+              }
+            }
+          } else {
+            if (output.exists()) {
+              output.delete()
+            }
+          }
+        }
+      }
+    }
     MapStatus(blockManager.shuffleServerId, sizes)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/330961bb/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index 5865e76..09c8649 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -21,8 +21,9 @@ import org.apache.spark._
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.serializer.Serializer
-import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle}
+import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleWriter}
 import org.apache.spark.storage.ShuffleBlockId
+import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.ExternalSorter
 
 private[spark] class SortShuffleWriter[K, V, C](
@@ -75,11 +76,11 @@ private[spark] class SortShuffleWriter[K, V, C](
     // Don't bother including the time to open the merged output file in the shuffle write time,
     // because it just opens a single file, so is typically too fast to measure accurately
     // (see SPARK-3570).
-    val outputFile = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
+    val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
+    val tmp = Utils.tempFileWith(output)
     val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
-    val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile)
-    shuffleBlockResolver.writeIndexFile(dep.shuffleId, mapId, partitionLengths)
-
+    val partitionLengths = sorter.writePartitionedFile(blockId, context, tmp)
+    shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
     mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/330961bb/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
index 49d9154..b0ba6df 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
@@ -35,7 +35,7 @@ import org.apache.spark.util.Utils
  */
 private[spark] class DiskBlockObjectWriter(
     val blockId: BlockId,
-    file: File,
+    val file: File,
     serializerInstance: SerializerInstance,
     bufferSize: Int,
     compressStream: OutputStream => OutputStream,

http://git-wip-us.apache.org/repos/asf/spark/blob/330961bb/core/src/main/scala/org/apache/spark/util/Utils.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 2dd2c03..9dce33c 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -21,8 +21,8 @@ import java.io._
 import java.lang.management.ManagementFactory
 import java.net._
 import java.nio.ByteBuffer
-import java.util.{Properties, Locale, Random, UUID}
 import java.util.concurrent._
+import java.util.{Locale, Properties, Random, UUID}
 import javax.net.ssl.HttpsURLConnection
 
 import scala.collection.JavaConversions._
@@ -30,7 +30,7 @@ import scala.collection.Map
 import scala.collection.mutable.ArrayBuffer
 import scala.io.Source
 import scala.reflect.ClassTag
-import scala.util.{Failure, Success, Try}
+import scala.util.Try
 import scala.util.control.{ControlThrowable, NonFatal}
 
 import com.google.common.io.{ByteStreams, Files}
@@ -42,7 +42,6 @@ import org.apache.hadoop.security.UserGroupInformation
 import org.apache.log4j.PropertyConfigurator
 import org.eclipse.jetty.util.MultiException
 import org.json4s._
-
 import tachyon.TachyonURI
 import tachyon.client.{TachyonFS, TachyonFile}
 
@@ -2152,6 +2151,12 @@ private[spark] object Utils extends Logging {
       conf.getInt("spark.executor.instances", 0) == 0
   }
 
+  /**
+   * Returns a path of temporary file which is in the same directory with `path`.
+   */
+  def tempFileWith(path: File): File = {
+    new File(path.getAbsolutePath + "." + UUID.randomUUID())
+  }
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/330961bb/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
index a266b0c..7054fdc 100644
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
@@ -174,9 +174,13 @@ public class UnsafeShuffleWriterSuite {
       @Override
       public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
         partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2];
+        File tmp = (File) invocationOnMock.getArguments()[3];
+        mergedOutputFile.delete();
+        tmp.renameTo(mergedOutputFile);
         return null;
       }
-    }).when(shuffleBlockResolver).writeIndexFile(anyInt(), anyInt(), any(long[].class));
+    }).when(shuffleBlockResolver)
+      .writeIndexFileAndCommit(anyInt(), anyInt(), any(long[].class), any(File.class));
 
     when(diskBlockManager.createTempShuffleBlock()).thenAnswer(
       new Answer<Tuple2<TempShuffleBlockId, File>>() {

http://git-wip-us.apache.org/repos/asf/spark/blob/330961bb/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index d91b799..fddde94 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -17,13 +17,17 @@
 
 package org.apache.spark
 
+import java.util.concurrent.{Callable, CyclicBarrier, ExecutorService, Executors}
+
 import org.scalatest.Matchers
 
 import org.apache.spark.ShuffleSuite.NonJavaSerializableClass
 import org.apache.spark.rdd.{CoGroupedRDD, OrderedRDDFunctions, RDD, ShuffledRDD, SubtractedRDD}
-import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
+import org.apache.spark.scheduler.{MapStatus, MyRDD, SparkListener, SparkListenerTaskEnd}
 import org.apache.spark.serializer.KryoSerializer
-import org.apache.spark.storage.{ShuffleDataBlockId, ShuffleBlockId}
+import org.apache.spark.shuffle.ShuffleWriter
+import org.apache.spark.storage.{ShuffleBlockId, ShuffleDataBlockId}
+import org.apache.spark.unsafe.memory.{HeapMemoryAllocator, ExecutorMemoryManager, TaskMemoryManager}
 import org.apache.spark.util.MutablePair
 
 abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkContext {
@@ -315,6 +319,107 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
     assert(metrics.bytesWritten === metrics.byresRead)
     assert(metrics.bytesWritten > 0)
   }
+
+  test("multiple simultaneous attempts for one task (SPARK-8029)") {
+    sc = new SparkContext("local", "test", conf)
+    val mapTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
+    val manager = sc.env.shuffleManager
+    val executionMemoryManager = new ExecutorMemoryManager(new HeapMemoryAllocator)
+    val taskMemoryManager = new TaskMemoryManager(executionMemoryManager)
+    val metricsSystem = sc.env.metricsSystem
+    val shuffleMapRdd = new MyRDD(sc, 1, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1))
+    val shuffleHandle = manager.registerShuffle(0, 1, shuffleDep)
+
+    // first attempt -- it's successful
+    val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0,
+      new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, metricsSystem,
+        InternalAccumulator.create(sc)))
+    val data1 = (1 to 10).map { x => x -> x}
+
+    // second attempt -- also successful.  We'll write out different data,
+    // just to simulate the fact that the records may get written differently
+    // depending on what gets spilled, what gets combined, etc.
+    val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0,
+      new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, metricsSystem,
+        InternalAccumulator.create(sc)))
+    val data2 = (11 to 20).map { x => x -> x}
+
+    // interleave writes of both attempts -- we want to test that both attempts can occur
+    // simultaneously, and everything is still OK
+
+    def writeAndClose(
+      writer: ShuffleWriter[Int, Int])(
+      iter: Iterator[(Int, Int)]): Option[MapStatus] = {
+      val files = writer.write(iter)
+      writer.stop(true)
+    }
+    val interleaver = new InterleaveIterators(
+      data1, writeAndClose(writer1), data2, writeAndClose(writer2))
+    val (mapOutput1, mapOutput2) = interleaver.run()
+
+    // check that we can read the map output and it has the right data
+    assert(mapOutput1.isDefined)
+    assert(mapOutput2.isDefined)
+    assert(mapOutput1.get.location === mapOutput2.get.location)
+    assert(mapOutput1.get.getSizeForBlock(0) === mapOutput1.get.getSizeForBlock(0))
+
+    // register one of the map outputs -- doesn't matter which one
+    mapOutput1.foreach { case mapStatus =>
+      mapTrackerMaster.registerMapOutputs(0, Array(mapStatus))
+    }
+
+    val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1,
+      new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, metricsSystem,
+        InternalAccumulator.create(sc)))
+    val readData = reader.read().toIndexedSeq
+    assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq)
+
+    manager.unregisterShuffle(0)
+  }
+}
+
+/**
+ * Utility to help tests make sure that we can process two different iterators simultaneously
+ * in different threads.  This makes sure that in your test, you don't completely process data1 with
+ * f1 before processing data2 with f2 (or vice versa).  It adds a barrier so that the functions only
+ * process one element, before pausing to wait for the other function to "catch up".
+ */
+class InterleaveIterators[T, R](
+  data1: Seq[T],
+  f1: Iterator[T] => R,
+  data2: Seq[T],
+  f2: Iterator[T] => R) {
+
+  require(data1.size == data2.size)
+
+  val barrier = new CyclicBarrier(2)
+  class BarrierIterator[E](id: Int, sub: Iterator[E]) extends Iterator[E] {
+    def hasNext: Boolean = sub.hasNext
+
+    def next: E = {
+      barrier.await()
+      sub.next()
+    }
+  }
+
+  val c1 = new Callable[R] {
+    override def call(): R = f1(new BarrierIterator(1, data1.iterator))
+  }
+  val c2 = new Callable[R] {
+    override def call(): R = f2(new BarrierIterator(2, data2.iterator))
+  }
+
+  val e: ExecutorService = Executors.newFixedThreadPool(2)
+
+  def run(): (R, R) = {
+    val future1 = e.submit(c1)
+    val future2 = e.submit(c2)
+    val r1 = future1.get()
+    val r2 = future2.get()
+    e.shutdown()
+    (r1, r2)
+  }
 }
 
 object ShuffleSuite {

http://git-wip-us.apache.org/repos/asf/spark/blob/330961bb/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
index 491dc36..7bcbbdf 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
@@ -46,6 +46,7 @@ class HashShuffleManagerSuite extends SparkFunSuite with LocalSparkContext {
     // an object is written. So if the codepaths assume writeObject is end of data, this should
     // flush those bugs out. This was common bug in ExternalAppendOnlyMap, etc.
     conf.set("spark.serializer.objectStreamReset", "1")
+    conf.set("spark.shuffle.consolidateFiles", "true")
     conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager")
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org

