spark-commits mailing list archives

From wenc...@apache.org
Subject spark git commit: [SPARK-19659] Fetch big blocks to disk when shuffle-read.
Date Thu, 25 May 2017 08:11:59 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-2.2 b52a06d70 -> 8896c4ee9


[SPARK-19659] Fetch big blocks to disk when shuffle-read.

## What changes were proposed in this pull request?

Currently the whole block is fetched into memory (off-heap by default) during shuffle read. A
block is identified by (shuffleId, mapId, reduceId), so it can be very large when the data is
skewed. If an OOM happens during shuffle read, the job is killed and users are advised to
"Consider boosting spark.yarn.executor.memoryOverhead". Adjusting the parameter and allocating
more memory can resolve the OOM, but this approach is not well suited to production
environments, especially data warehouses.
When Spark SQL is the data engine of a warehouse, users want a unified parameter (e.g. for
memory) with little wasted resource (resource that is allocated but not used). This matters
most when migrating the data engine to Spark from another engine (e.g. Hive): tuning the
parameter for thousands of SQL queries one by one is very time consuming.
Skew is not always easy to predict; when it happens, it makes more sense to fetch remote
blocks to disk for the shuffle read than to kill the job with an OOM.

In this PR, I propose to fetch big blocks to disk (an idea also mentioned in SPARK-3019); step 1 is sketched after the list:

1. Track the average block size, and also the outliers (blocks larger than 2*avgSize), in `MapStatus`;
2. Request memory from `MemoryManager` before fetching blocks, and release it back to `MemoryManager`
when the `ManagedBuffer` is released;
3. Fetch remote blocks to disk when acquiring memory from `MemoryManager` fails; otherwise
fetch them to memory.
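
A rough sketch of step 1 (illustrative only; the `MapStatus` changes are not part of this cherry-picked diff, and `averageAndOutliers` is a hypothetical helper, not Spark code):

    // Track the average size of non-empty blocks and flag outliers larger
    // than 2 * avgSize, so the reducer can treat those blocks specially.
    def averageAndOutliers(sizes: Array[Long]): (Long, Set[Int]) = {
      val nonEmpty = sizes.filter(_ > 0)
      val avg = if (nonEmpty.isEmpty) 0L else nonEmpty.sum / nonEmpty.length
      val outliers = sizes.zipWithIndex.collect {
        case (size, i) if size > 2 * avg => i
      }.toSet
      (avg, outliers)
    }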

This improves memory control when shuffling blocks and helps avoid OOM in scenarios like the
following:
1. A single huge block;
2. The sizes of many blocks are underestimated in `MapStatus`, so the actual footprint of the
blocks is much larger than estimated.

## How was this patch tested?
Added unit tests in `MapStatusSuite` and `ShuffleBlockFetcherIteratorSuite`.

Author: jinxing <jinxing6042@126.com>

Closes #16989 from jinxing64/SPARK-19659.

(cherry picked from commit 3f94e64aa8fd806ae1fa0156d846ce96afacddd3)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8896c4ee
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8896c4ee
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8896c4ee

Branch: refs/heads/branch-2.2
Commit: 8896c4ee9ea315a7dcd1a05b7201e7ad0539a5ed
Parents: b52a06d
Author: jinxing <jinxing6042@126.com>
Authored: Thu May 25 16:11:30 2017 +0800
Committer: Wenchen Fan <wenchen@databricks.com>
Committed: Thu May 25 16:11:51 2017 +0800

----------------------------------------------------------------------
 .../network/server/OneForOneStreamManager.java  | 21 +++++
 .../network/shuffle/ExternalShuffleClient.java  |  7 +-
 .../network/shuffle/OneForOneBlockFetcher.java  | 62 +++++++++++++-
 .../spark/network/shuffle/ShuffleClient.java    |  4 +-
 .../network/sasl/SaslIntegrationSuite.java      |  2 +-
 .../ExternalShuffleIntegrationSuite.java        |  2 +-
 .../shuffle/OneForOneBlockFetcherSuite.java     |  7 +-
 .../apache/spark/internal/config/package.scala  |  6 ++
 .../spark/network/BlockTransferService.scala    |  7 +-
 .../netty/NettyBlockTransferService.scala       |  7 +-
 .../spark/shuffle/BlockStoreShuffleReader.scala |  3 +-
 .../storage/ShuffleBlockFetcherIterator.scala   | 71 ++++++++++------
 .../apache/spark/MapOutputTrackerSuite.scala    |  2 +-
 .../netty/NettyBlockTransferSecuritySuite.scala |  2 +-
 .../spark/storage/BlockManagerSuite.scala       |  4 +-
 .../ShuffleBlockFetcherIteratorSuite.scala      | 86 ++++++++++++++++++--
 docs/configuration.md                           |  8 ++
 17 files changed, 254 insertions(+), 47 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8896c4ee/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
index ee367f9..ad8e8b4 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
@@ -23,6 +23,8 @@ import java.util.Random;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicLong;
 
+import scala.Tuple2;
+
 import com.google.common.base.Preconditions;
 import io.netty.channel.Channel;
 import org.slf4j.Logger;
@@ -95,6 +97,25 @@ public class OneForOneStreamManager extends StreamManager {
   }
 
   @Override
+  public ManagedBuffer openStream(String streamChunkId) {
+    Tuple2<Long, Integer> streamIdAndChunkId = parseStreamChunkId(streamChunkId);
+    return getChunk(streamIdAndChunkId._1, streamIdAndChunkId._2);
+  }
+
+  public static String genStreamChunkId(long streamId, int chunkId) {
+    return String.format("%d_%d", streamId, chunkId);
+  }
+
+  public static Tuple2<Long, Integer> parseStreamChunkId(String streamChunkId) {
+    String[] array = streamChunkId.split("_");
+    assert array.length == 2:
+      "Stream id and chunk index should be specified when open stream for fetching block.";
+    long streamId = Long.valueOf(array[0]);
+    int chunkIndex = Integer.valueOf(array[1]);
+    return new Tuple2<>(streamId, chunkIndex);
+  }
+
+  @Override
   public void connectionTerminated(Channel channel) {
     // Close all streams which have been associated with the channel.
     for (Map.Entry<Long, StreamState> entry: streams.entrySet()) {
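
The two new helpers are inverses of each other; a quick round-trip check (a sketch, assuming spark-network-common is on the classpath):

    import org.apache.spark.network.server.OneForOneStreamManager

    // genStreamChunkId encodes "<streamId>_<chunkIndex>"; parseStreamChunkId decodes it.
    val streamChunkId = OneForOneStreamManager.genStreamChunkId(123L, 4) // "123_4"
    val parsed = OneForOneStreamManager.parseStreamChunkId(streamChunkId)
    assert(parsed._1 == 123L && parsed._2 == 4)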

http://git-wip-us.apache.org/repos/asf/spark/blob/8896c4ee/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
----------------------------------------------------------------------
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
index 2c5827b..269fa72 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
@@ -17,6 +17,7 @@
 
 package org.apache.spark.network.shuffle;
 
+import java.io.File;
 import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.List;
@@ -86,14 +87,16 @@ public class ExternalShuffleClient extends ShuffleClient {
       int port,
       String execId,
       String[] blockIds,
-      BlockFetchingListener listener) {
+      BlockFetchingListener listener,
+      File[] shuffleFiles) {
     checkInit();
     logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId);
     try {
       RetryingBlockFetcher.BlockFetchStarter blockFetchStarter =
           (blockIds1, listener1) -> {
             TransportClient client = clientFactory.createClient(host, port);
-            new OneForOneBlockFetcher(client, appId, execId, blockIds1, listener1).start();
+            new OneForOneBlockFetcher(client, appId, execId, blockIds1, listener1, conf,
+              shuffleFiles).start();
           };
 
       int maxRetries = conf.maxIORetries();

http://git-wip-us.apache.org/repos/asf/spark/blob/8896c4ee/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
----------------------------------------------------------------------
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
index 35f69fe..5f42875 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
@@ -17,19 +17,28 @@
 
 package org.apache.spark.network.shuffle;
 
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
 import java.nio.ByteBuffer;
+import java.nio.channels.Channels;
+import java.nio.channels.WritableByteChannel;
 import java.util.Arrays;
 
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
 import org.apache.spark.network.buffer.ManagedBuffer;
 import org.apache.spark.network.client.ChunkReceivedCallback;
 import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.StreamCallback;
 import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.server.OneForOneStreamManager;
 import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
 import org.apache.spark.network.shuffle.protocol.OpenBlocks;
 import org.apache.spark.network.shuffle.protocol.StreamHandle;
+import org.apache.spark.network.util.TransportConf;
 
 /**
  * Simple wrapper on top of a TransportClient which interprets each chunk as a whole block, and
@@ -48,6 +57,8 @@ public class OneForOneBlockFetcher {
   private final String[] blockIds;
   private final BlockFetchingListener listener;
   private final ChunkReceivedCallback chunkCallback;
+  private TransportConf transportConf = null;
+  private File[] shuffleFiles = null;
 
   private StreamHandle streamHandle = null;
 
@@ -56,12 +67,20 @@ public class OneForOneBlockFetcher {
       String appId,
       String execId,
       String[] blockIds,
-      BlockFetchingListener listener) {
+      BlockFetchingListener listener,
+      TransportConf transportConf,
+      File[] shuffleFiles) {
     this.client = client;
     this.openMessage = new OpenBlocks(appId, execId, blockIds);
     this.blockIds = blockIds;
     this.listener = listener;
     this.chunkCallback = new ChunkCallback();
+    this.transportConf = transportConf;
+    if (shuffleFiles != null) {
+      this.shuffleFiles = shuffleFiles;
+      assert this.shuffleFiles.length == blockIds.length:
+        "Number of shuffle files should equal to blocks";
+    }
   }
 
   /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */
@@ -100,7 +119,12 @@ public class OneForOneBlockFetcher {
           // Immediately request all chunks -- we expect that the total size of the request is
           // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]].
           for (int i = 0; i < streamHandle.numChunks; i++) {
-            client.fetchChunk(streamHandle.streamId, i, chunkCallback);
+            if (shuffleFiles != null) {
+            client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i),
+                new DownloadCallback(shuffleFiles[i], i));
+            } else {
+              client.fetchChunk(streamHandle.streamId, i, chunkCallback);
+            }
           }
         } catch (Exception e) {
           logger.error("Failed while starting block fetches after success", e);
@@ -126,4 +150,38 @@ public class OneForOneBlockFetcher {
       }
     }
   }
+
+  private class DownloadCallback implements StreamCallback {
+
+    private WritableByteChannel channel = null;
+    private File targetFile = null;
+    private int chunkIndex;
+
+    public DownloadCallback(File targetFile, int chunkIndex) throws IOException {
+      this.targetFile = targetFile;
+      this.channel = Channels.newChannel(new FileOutputStream(targetFile));
+      this.chunkIndex = chunkIndex;
+    }
+
+    @Override
+    public void onData(String streamId, ByteBuffer buf) throws IOException {
+      channel.write(buf);
+    }
+
+    @Override
+    public void onComplete(String streamId) throws IOException {
+      channel.close();
+      ManagedBuffer buffer = new FileSegmentManagedBuffer(transportConf, targetFile, 0,
+        targetFile.length());
+      listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer);
+    }
+
+    @Override
+    public void onFailure(String streamId, Throwable cause) throws IOException {
+      channel.close();
+      // On receipt of a failure, fail every block from chunkIndex onwards.
+      String[] remainingBlockIds = Arrays.copyOfRange(blockIds, chunkIndex, blockIds.length);
+      failRemainingBlocks(remainingBlockIds, cause);
+    }
+  }
 }
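
Callers pick the mode through the new last constructor argument: null keeps the old in-memory chunk fetch, while one target File per block id makes each chunk stream to disk through `DownloadCallback`. A minimal sketch, assuming a connected `TransportClient`, a `BlockFetchingListener`, and a `TransportConf` (the block ids and temp files here are illustrative):

    import java.io.File
    import org.apache.spark.network.client.TransportClient
    import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher}
    import org.apache.spark.network.util.TransportConf

    def fetchBothWays(client: TransportClient, listener: BlockFetchingListener,
        conf: TransportConf): Unit = {
      val blockIds = Array("shuffle_0_1_2", "shuffle_0_3_2")

      // In-memory fetch (previous behaviour): no target files.
      new OneForOneBlockFetcher(client, "app-id", "exec-id", blockIds, listener, conf, null)
        .start()

      // Fetch-to-disk: one temp file per block id, in the same order as blockIds.
      val shuffleFiles: Array[File] = blockIds.map(id => File.createTempFile(id, ".tmp"))
      new OneForOneBlockFetcher(client, "app-id", "exec-id", blockIds, listener, conf, shuffleFiles)
        .start()
    }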

http://git-wip-us.apache.org/repos/asf/spark/blob/8896c4ee/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java
----------------------------------------------------------------------
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java
index f72ab40..978ff5a 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java
@@ -18,6 +18,7 @@
 package org.apache.spark.network.shuffle;
 
 import java.io.Closeable;
+import java.io.File;
 
 /** Provides an interface for reading shuffle files, either from an Executor or external service. */
 public abstract class ShuffleClient implements Closeable {
@@ -40,5 +41,6 @@ public abstract class ShuffleClient implements Closeable {
       int port,
       String execId,
       String[] blockIds,
-      BlockFetchingListener listener);
+      BlockFetchingListener listener,
+      File[] shuffleFiles);
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/8896c4ee/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
----------------------------------------------------------------------
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
index c0e170e..0c054fc 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
@@ -204,7 +204,7 @@ public class SaslIntegrationSuite {
 
       String[] blockIds = { "shuffle_2_3_4", "shuffle_6_7_8" };
       OneForOneBlockFetcher fetcher =
-          new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener);
+          new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener, conf, null);
       fetcher.start();
       blockFetchLatch.await();
       checkSecurityException(exception.get());

http://git-wip-us.apache.org/repos/asf/spark/blob/8896c4ee/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
----------------------------------------------------------------------
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
index 7a33b68..d1d8f5b 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
@@ -158,7 +158,7 @@ public class ExternalShuffleIntegrationSuite {
             }
           }
         }
-      });
+      }, null);
 
     if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) {
       fail("Timeout getting response from the server");

http://git-wip-us.apache.org/repos/asf/spark/blob/8896c4ee/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
----------------------------------------------------------------------
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
index 3e51fea..61d8221 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
@@ -46,8 +46,13 @@ import org.apache.spark.network.client.TransportClient;
 import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
 import org.apache.spark.network.shuffle.protocol.OpenBlocks;
 import org.apache.spark.network.shuffle.protocol.StreamHandle;
+import org.apache.spark.network.util.MapConfigProvider;
+import org.apache.spark.network.util.TransportConf;
 
 public class OneForOneBlockFetcherSuite {
+
+  private static final TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY);
+
   @Test
   public void testFetchOne() {
     LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
@@ -126,7 +131,7 @@ public class OneForOneBlockFetcherSuite {
     BlockFetchingListener listener = mock(BlockFetchingListener.class);
     String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
     OneForOneBlockFetcher fetcher =
-      new OneForOneBlockFetcher(client, "app-id", "exec-id", blockIds, listener);
+      new OneForOneBlockFetcher(client, "app-id", "exec-id", blockIds, listener, conf, null);
 
    // Respond to the "OpenBlocks" message with an appropriate ShuffleStreamHandle with streamId 123
     doAnswer(invocationOnMock -> {

http://git-wip-us.apache.org/repos/asf/spark/blob/8896c4ee/core/src/main/scala/org/apache/spark/internal/config/package.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index e193ed2..f8139b7 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -287,4 +287,10 @@ package object config {
       .bytesConf(ByteUnit.BYTE)
       .createWithDefault(100 * 1024 * 1024)
 
+  private[spark] val REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM =
+    ConfigBuilder("spark.reducer.maxReqSizeShuffleToMem")
+      .doc("The blocks of a shuffle request will be fetched to disk when size of the request
is " +
+        "above this threshold. This is to avoid a giant request takes too much memory.")
+      .bytesConf(ByteUnit.BYTE)
+      .createWithDefaultString("200m")
 }
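
The new entry is a standard byte-size conf, so jobs can lower the threshold (or effectively disable spilling with a huge value) without code changes; a sketch:

    import org.apache.spark.SparkConf

    // Any shuffle request larger than 100m will be fetched straight to disk.
    val conf = new SparkConf()
      .set("spark.reducer.maxReqSizeShuffleToMem", "100m")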

http://git-wip-us.apache.org/repos/asf/spark/blob/8896c4ee/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
index cb9d389..6860214 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.network
 
-import java.io.Closeable
+import java.io.{Closeable, File}
 import java.nio.ByteBuffer
 
 import scala.concurrent.{Future, Promise}
@@ -67,7 +67,8 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
       port: Int,
       execId: String,
       blockIds: Array[String],
-      listener: BlockFetchingListener): Unit
+      listener: BlockFetchingListener,
+      shuffleFiles: Array[File]): Unit
 
   /**
    * Upload a single block to a remote node, available only after [[init]] is invoked.
@@ -100,7 +101,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
           ret.flip()
           result.success(new NioManagedBuffer(ret))
         }
-      })
+      }, shuffleFiles = null)
     ThreadUtils.awaitResult(result.future, Duration.Inf)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/8896c4ee/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
index b75e91b..b13a9c6 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.network.netty
 
+import java.io.File
 import java.nio.ByteBuffer
 
 import scala.collection.JavaConverters._
@@ -88,13 +89,15 @@ private[spark] class NettyBlockTransferService(
       port: Int,
       execId: String,
       blockIds: Array[String],
-      listener: BlockFetchingListener): Unit = {
+      listener: BlockFetchingListener,
+      shuffleFiles: Array[File]): Unit = {
     logTrace(s"Fetch blocks from $host:$port (executor id $execId)")
     try {
       val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
        override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
           val client = clientFactory.createClient(host, port)
-          new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener).start()
+          new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener,
+            transportConf, shuffleFiles).start()
         }
       }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/8896c4ee/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index ba3e0e3..2fbac79 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.shuffle
 
 import org.apache.spark._
-import org.apache.spark.internal.Logging
+import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.serializer.SerializerManager
 import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator}
 import org.apache.spark.util.CompletionIterator
@@ -51,6 +51,7 @@ private[spark] class BlockStoreShuffleReader[K, C](
       // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
       SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
       SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
+      SparkEnv.get.conf.get(config.REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM),
       SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))
 
     val serializerInstance = dep.serializer.newInstance()

http://git-wip-us.apache.org/repos/asf/spark/blob/8896c4ee/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index f890611..ee35060 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.storage
 
-import java.io.{InputStream, IOException}
+import java.io.{File, InputStream, IOException}
 import java.nio.ByteBuffer
 import java.util.concurrent.LinkedBlockingQueue
 import javax.annotation.concurrent.GuardedBy
@@ -52,6 +52,7 @@ import org.apache.spark.util.io.ChunkedByteBufferOutputStream
  * @param streamWrapper A function to wrap the returned input stream.
  * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
  * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point.
+ * @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory.
  * @param detectCorrupt whether to detect any corruption in fetched blocks.
  */
 private[spark]
@@ -63,6 +64,7 @@ final class ShuffleBlockFetcherIterator(
     streamWrapper: (BlockId, InputStream) => InputStream,
     maxBytesInFlight: Long,
     maxReqsInFlight: Int,
+    maxReqSizeShuffleToMem: Long,
     detectCorrupt: Boolean)
   extends Iterator[(BlockId, InputStream)] with Logging {
 
@@ -129,6 +131,12 @@ final class ShuffleBlockFetcherIterator(
   @GuardedBy("this")
   private[this] var isZombie = false
 
+  /**
+   * A set to store the files used for shuffling remote huge blocks. Files in this set will be
+   * deleted during cleanup. This is a layer of defensiveness against disk file leaks.
+   */
+  val shuffleFilesSet = mutable.HashSet[File]()
+
   initialize()
 
   // Decrements the buffer reference count.
@@ -163,6 +171,11 @@ final class ShuffleBlockFetcherIterator(
         case _ =>
       }
     }
+    shuffleFilesSet.foreach { file =>
+      if (!file.delete()) {
+        logInfo("Failed to cleanup shuffle fetch temp file " + file.getAbsolutePath());
+      }
+    }
   }
 
   private[this] def sendRequest(req: FetchRequest) {
@@ -175,33 +188,45 @@ final class ShuffleBlockFetcherIterator(
     val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
     val remainingBlocks = new HashSet[String]() ++= sizeMap.keys
     val blockIds = req.blocks.map(_._1.toString)
-
     val address = req.address
-    shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
-      new BlockFetchingListener {
-        override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
-          // Only add the buffer to results queue if the iterator is not zombie,
-          // i.e. cleanup() has not been called yet.
-          ShuffleBlockFetcherIterator.this.synchronized {
-            if (!isZombie) {
-              // Increment the ref count because we need to pass this to a different thread.
-              // This needs to be released after use.
-              buf.retain()
-              remainingBlocks -= blockId
-              results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf,
-                remainingBlocks.isEmpty))
-              logDebug("remainingBlocks: " + remainingBlocks)
-            }
+
+    val blockFetchingListener = new BlockFetchingListener {
+      override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
+        // Only add the buffer to results queue if the iterator is not zombie,
+        // i.e. cleanup() has not been called yet.
+        ShuffleBlockFetcherIterator.this.synchronized {
+          if (!isZombie) {
+            // Increment the ref count because we need to pass this to a different thread.
+            // This needs to be released after use.
+            buf.retain()
+            remainingBlocks -= blockId
+            results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf,
+              remainingBlocks.isEmpty))
+            logDebug("remainingBlocks: " + remainingBlocks)
           }
-          logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
         }
+        logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
+      }
 
-        override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
-          logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
-          results.put(new FailureFetchResult(BlockId(blockId), address, e))
-        }
+      override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
+        logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
+        results.put(new FailureFetchResult(BlockId(blockId), address, e))
       }
-    )
+    }
+
+    // Shuffle remote blocks to disk when the request is too large.
+    // TODO: Encryption and compression should be considered.
+    if (req.size > maxReqSizeShuffleToMem) {
+      val shuffleFiles = blockIds.map {
+        bId => blockManager.diskBlockManager.createTempLocalBlock()._2
+      }.toArray
+      shuffleFilesSet ++= shuffleFiles
+      shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
+        blockFetchingListener, shuffleFiles)
+    } else {
+      shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
+        blockFetchingListener, null)
+    }
   }
 
   private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
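
The iterator remembers every temp file it creates for an over-threshold request in `shuffleFilesSet` and best-effort deletes them in `cleanup()`. A stripped-down model of that bookkeeping (hypothetical class, no Spark dependencies):

    import java.io.File
    import scala.collection.mutable

    class TempFileTracker {
      private val files = mutable.HashSet[File]()

      // Called when a fetch-to-disk request creates its target files.
      def register(created: Seq[File]): Unit = files ++= created

      // Called on cleanup: deletion failures are logged, not fatal, since
      // this is only a defensive layer against disk file leaks.
      def cleanup(): Unit = files.foreach { f =>
        if (!f.delete()) println(s"Failed to cleanup temp file ${f.getAbsolutePath}")
      }
    }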

http://git-wip-us.apache.org/repos/asf/spark/blob/8896c4ee/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index bb24c6c..71bedda 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.mockito.Matchers.{any, isA}
+import org.mockito.Matchers.any
 import org.mockito.Mockito._
 
 import org.apache.spark.broadcast.BroadcastManager

http://git-wip-us.apache.org/repos/asf/spark/blob/8896c4ee/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
index 792a1d7..474e301 100644
--- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
+++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
@@ -165,7 +165,7 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi
         override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
           promise.success(data.retain())
         }
-      })
+      }, null)
 
     ThreadUtils.awaitReady(promise.future, FiniteDuration(10, TimeUnit.SECONDS))
     promise.future.value.get

http://git-wip-us.apache.org/repos/asf/spark/blob/8896c4ee/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index a8b9604..9d7a869 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.storage
 
+import java.io.File
 import java.nio.ByteBuffer
 
 import scala.collection.mutable.ArrayBuffer
@@ -1265,7 +1266,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
         port: Int,
         execId: String,
         blockIds: Array[String],
-        listener: BlockFetchingListener): Unit = {
+        listener: BlockFetchingListener,
+        shuffleFiles: Array[File]): Unit = {
       listener.onBlockFetchSuccess("mockBlockId", new NioManagedBuffer(ByteBuffer.allocate(1)))
     }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/8896c4ee/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 9900d1e..1f813a9 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.storage
 
 import java.io.{File, InputStream, IOException}
+import java.util.UUID
 import java.util.concurrent.Semaphore
 
 import scala.concurrent.ExecutionContext.Implicits.global
@@ -44,7 +45,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
   /** Creates a mock [[BlockTransferService]] that returns data from the given map. */
  private def createMockTransfer(data: Map[BlockId, ManagedBuffer]): BlockTransferService = {
     val transfer = mock(classOf[BlockTransferService])
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
+      .thenAnswer(new Answer[Unit] {
       override def answer(invocation: InvocationOnMock): Unit = {
         val blocks = invocation.getArguments()(3).asInstanceOf[Array[String]]
         val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
@@ -106,6 +108,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       (_, in) => in,
       48 * 1024 * 1024,
       Int.MaxValue,
+      Int.MaxValue,
       true)
 
     // 3 local blocks fetched in initialization
@@ -134,7 +137,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     // 3 local blocks, and 2 remote blocks
     // (but from the same block manager so one call to fetchBlocks)
     verify(blockManager, times(3)).getBlockData(any())
-    verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any())
+    verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any())
   }
 
   test("release current unexhausted buffer in case the task completes early") {
@@ -153,7 +156,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val sem = new Semaphore(0)
 
     val transfer = mock(classOf[BlockTransferService])
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
+      .thenAnswer(new Answer[Unit] {
       override def answer(invocation: InvocationOnMock): Unit = {
         val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
         Future {
@@ -181,6 +185,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       (_, in) => in,
       48 * 1024 * 1024,
       Int.MaxValue,
+      Int.MaxValue,
       true)
 
     verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release()
@@ -218,7 +223,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val sem = new Semaphore(0)
 
     val transfer = mock(classOf[BlockTransferService])
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
+      .thenAnswer(new Answer[Unit] {
       override def answer(invocation: InvocationOnMock): Unit = {
         val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
         Future {
@@ -246,6 +252,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       (_, in) => in,
       48 * 1024 * 1024,
       Int.MaxValue,
+      Int.MaxValue,
       true)
 
     // Continue only after the mock calls onBlockFetchFailure
@@ -281,7 +288,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 0, 100)
 
     val transfer = mock(classOf[BlockTransferService])
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
+      .thenAnswer(new Answer[Unit] {
       override def answer(invocation: InvocationOnMock): Unit = {
         val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
         Future {
@@ -309,6 +317,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       (_, in) => new LimitedInputStream(in, 100),
       48 * 1024 * 1024,
       Int.MaxValue,
+      Int.MaxValue,
       true)
 
     // Continue only after the mock calls onBlockFetchFailure
@@ -318,7 +327,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val (id1, _) = iterator.next()
     assert(id1 === ShuffleBlockId(0, 0, 0))
 
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
+      .thenAnswer(new Answer[Unit] {
       override def answer(invocation: InvocationOnMock): Unit = {
         val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
         Future {
@@ -359,7 +369,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     when(corruptBuffer.createInputStream()).thenReturn(corruptStream)
 
     val transfer = mock(classOf[BlockTransferService])
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
+      .thenAnswer(new Answer[Unit] {
       override def answer(invocation: InvocationOnMock): Unit = {
         val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
         Future {
@@ -387,6 +398,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       (_, in) => new LimitedInputStream(in, 100),
       48 * 1024 * 1024,
       Int.MaxValue,
+      Int.MaxValue,
       false)
 
     // Continue only after the mock calls onBlockFetchFailure
@@ -401,4 +413,64 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     assert(id3 === ShuffleBlockId(0, 2, 0))
   }
 
+  test("Blocks should be shuffled to disk when size of the request is above the" +
+    " threshold(maxReqSizeShuffleToMem).") {
+    val blockManager = mock(classOf[BlockManager])
+    val localBmId = BlockManagerId("test-client", "test-client", 1)
+    doReturn(localBmId).when(blockManager).blockManagerId
+
+    val diskBlockManager = mock(classOf[DiskBlockManager])
+    doReturn{
+      var blockId = new TempLocalBlockId(UUID.randomUUID())
+      (blockId, new File(blockId.name))
+    }.when(diskBlockManager).createTempLocalBlock()
+    doReturn(diskBlockManager).when(blockManager).diskBlockManager
+
+    val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
+    val remoteBlocks = Map[BlockId, ManagedBuffer](
+      ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer())
+    val transfer = mock(classOf[BlockTransferService])
+    var shuffleFiles: Array[File] = null
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
+      .thenAnswer(new Answer[Unit] {
+        override def answer(invocation: InvocationOnMock): Unit = {
+          val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+          shuffleFiles = invocation.getArguments()(5).asInstanceOf[Array[File]]
+          Future {
+            listener.onBlockFetchSuccess(
+              ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0)))
+          }
+        }
+      })
+
+    val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+      (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L)).toSeq))
+    // Set maxReqSizeShuffleToMem to be 200.
+    val iterator1 = new ShuffleBlockFetcherIterator(
+      TaskContext.empty(),
+      transfer,
+      blockManager,
+      blocksByAddress1,
+      (_, in) => in,
+      Int.MaxValue,
+      Int.MaxValue,
+      200,
+      true)
+    assert(shuffleFiles === null)
+
+    val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+      (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq))
+    // Set maxReqSizeShuffleToMem to be 200.
+    val iterator2 = new ShuffleBlockFetcherIterator(
+      TaskContext.empty(),
+      transfer,
+      blockManager,
+      blocksByAddress2,
+      (_, in) => in,
+      Int.MaxValue,
+      Int.MaxValue,
+      200,
+      true)
+    assert(shuffleFiles != null)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/8896c4ee/docs/configuration.md
----------------------------------------------------------------------
diff --git a/docs/configuration.md b/docs/configuration.md
index a6b6d5d..0771e36 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -520,6 +520,14 @@ Apart from these, the following properties are also available, and may be useful
   </td>
 </tr>
 <tr>
+  <td><code>spark.reducer.maxReqSizeShuffleToMem</code></td>
+  <td>200m</td>
+  <td>
+    The blocks of a shuffle request will be fetched to disk when the size of the request is
+    above this threshold. This is to avoid a giant request taking too much memory.
+  </td>
+</tr>
+<tr>
   <td><code>spark.shuffle.compress</code></td>
   <td>true</td>
   <td>
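
The documented default of 200m follows Spark's byte-string parsing (binary units); a quick check, assuming spark-network-common is on the classpath:

    import org.apache.spark.network.util.JavaUtils

    // "200m" parses to 200 * 1024 * 1024 bytes.
    assert(JavaUtils.byteStringAsBytes("200m") == 200L * 1024 * 1024)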


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org

