spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From r...@apache.org
Subject spark git commit: [SPARK-4188] [Core] Perform network-level retry of shuffle file fetches
Date Fri, 07 Nov 2014 02:39:18 GMT
Repository: spark
Updated Branches:
  refs/heads/master 6e9ef10fd -> f165b2bbf


[SPARK-4188] [Core] Perform network-level retry of shuffle file fetches

This adds a RetryingBlockFetcher to the NettyBlockTransferService which is wrapped around our typical OneForOneBlockFetcher, adding retry logic in the event of an IOException.

This sort of retry allows us to avoid marking an entire executor as failed due to garbage collection or high network load.

TODO:
- [x] unit tests
- [x] put in ExternalShuffleClient too

Author: Aaron Davidson <aaron@databricks.com>

Closes #3101 from aarondav/retry and squashes the following commits:

72a2a32 [Aaron Davidson] Add that we should remove the condition around the retry thingy
c7fd107 [Aaron Davidson] Fix unit tests
e80e4c2 [Aaron Davidson] Address initial comments
6f594cd [Aaron Davidson] Fix unit test
05ff43c [Aaron Davidson] Add to external shuffle client and add unit test
66e5a24 [Aaron Davidson] [SPARK-4238] [Core] Perform network-level retry of shuffle file fetches


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f165b2bb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f165b2bb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f165b2bb

Branch: refs/heads/master
Commit: f165b2bbf5d4acf34d826fa55b900f5bbc295654
Parents: 6e9ef10
Author: Aaron Davidson <aaron@databricks.com>
Authored: Thu Nov 6 18:39:14 2014 -0800
Committer: Reynold Xin <rxin@databricks.com>
Committed: Thu Nov 6 18:39:14 2014 -0800

----------------------------------------------------------------------
 .../netty/NettyBlockTransferService.scala       |  21 +-
 .../spark/network/client/TransportClient.java   |  16 +-
 .../network/client/TransportClientFactory.java  |  13 +-
 .../client/TransportResponseHandler.java        |   3 +-
 .../spark/network/protocol/MessageEncoder.java  |   2 +-
 .../spark/network/server/TransportServer.java   |   8 +-
 .../apache/spark/network/util/NettyUtils.java   |  14 +-
 .../spark/network/util/TransportConf.java       |  17 +
 .../network/TransportClientFactorySuite.java    |   7 +-
 .../network/shuffle/ExternalShuffleClient.java  |  31 +-
 .../network/shuffle/OneForOneBlockFetcher.java  |   9 +-
 .../network/shuffle/RetryingBlockFetcher.java   | 234 ++++++++++++++
 .../network/sasl/SaslIntegrationSuite.java      |   4 +-
 .../ExternalShuffleIntegrationSuite.java        |  18 +-
 .../shuffle/ExternalShuffleSecuritySuite.java   |   6 +-
 .../shuffle/RetryingBlockFetcherSuite.java      | 310 +++++++++++++++++++
 16 files changed, 668 insertions(+), 45 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f165b2bb/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
index 0d1fc81..b937ea8 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -27,7 +27,7 @@ import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCal
 import org.apache.spark.network.netty.NettyMessages.{OpenBlocks, UploadBlock}
 import org.apache.spark.network.sasl.{SaslRpcHandler, SaslClientBootstrap}
 import org.apache.spark.network.server._
-import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher}
+import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher}
 import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.storage.{BlockId, StorageLevel}
 import org.apache.spark.util.Utils
@@ -71,9 +71,22 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
       listener: BlockFetchingListener): Unit = {
     logTrace(s"Fetch blocks from $host:$port (executor id $execId)")
     try {
-      val client = clientFactory.createClient(host, port)
-      new OneForOneBlockFetcher(client, blockIds.toArray, listener)
-        .start(OpenBlocks(blockIds.map(BlockId.apply)))
+      val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
+        override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
+          val client = clientFactory.createClient(host, port)
+          new OneForOneBlockFetcher(client, blockIds.toArray, listener)
+            .start(OpenBlocks(blockIds.map(BlockId.apply)))
+        }
+      }
+
+      val maxRetries = transportConf.maxIORetries()
+      if (maxRetries > 0) {
+        // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's
+        // a bug in this code. We should remove the if statement once we're sure of the stability.
+        new RetryingBlockFetcher(transportConf, blockFetchStarter, blockIds, listener).start()
+      } else {
+        blockFetchStarter.createAndStart(blockIds, listener)
+      }
     } catch {
       case e: Exception =>
         logError("Exception while beginning fetchBlocks", e)

http://git-wip-us.apache.org/repos/asf/spark/blob/f165b2bb/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
index a08cee0..4e94411 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -18,7 +18,9 @@
 package org.apache.spark.network.client;
 
 import java.io.Closeable;
+import java.io.IOException;
 import java.util.UUID;
+import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
 
 import com.google.common.base.Objects;
@@ -116,8 +118,12 @@ public class TransportClient implements Closeable {
               serverAddr, future.cause());
             logger.error(errorMsg, future.cause());
             handler.removeFetchRequest(streamChunkId);
-            callback.onFailure(chunkIndex, new RuntimeException(errorMsg, future.cause()));
             channel.close();
+            try {
+              callback.onFailure(chunkIndex, new IOException(errorMsg, future.cause()));
+            } catch (Exception e) {
+              logger.error("Uncaught exception in RPC response callback handler!", e);
+            }
           }
         }
       });
@@ -147,8 +153,12 @@ public class TransportClient implements Closeable {
               serverAddr, future.cause());
             logger.error(errorMsg, future.cause());
             handler.removeRpcRequest(requestId);
-            callback.onFailure(new RuntimeException(errorMsg, future.cause()));
             channel.close();
+            try {
+              callback.onFailure(new IOException(errorMsg, future.cause()));
+            } catch (Exception e) {
+              logger.error("Uncaught exception in RPC response callback handler!", e);
+            }
           }
         }
       });
@@ -175,6 +185,8 @@ public class TransportClient implements Closeable {
 
     try {
       return result.get(timeoutMs, TimeUnit.MILLISECONDS);
+    } catch (ExecutionException e) {
+      throw Throwables.propagate(e.getCause());
     } catch (Exception e) {
       throw Throwables.propagate(e);
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/f165b2bb/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
index 1723fed..397d3a8 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
@@ -18,12 +18,12 @@
 package org.apache.spark.network.client;
 
 import java.io.Closeable;
+import java.io.IOException;
 import java.lang.reflect.Field;
 import java.net.InetSocketAddress;
 import java.net.SocketAddress;
 import java.util.List;
 import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.AtomicReference;
 
 import com.google.common.base.Preconditions;
@@ -44,7 +44,6 @@ import org.slf4j.LoggerFactory;
 import org.apache.spark.network.TransportContext;
 import org.apache.spark.network.server.TransportChannelHandler;
 import org.apache.spark.network.util.IOMode;
-import org.apache.spark.network.util.JavaUtils;
 import org.apache.spark.network.util.NettyUtils;
 import org.apache.spark.network.util.TransportConf;
 
@@ -93,15 +92,17 @@ public class TransportClientFactory implements Closeable {
    *
    * Concurrency: This method is safe to call from multiple threads.
    */
-  public TransportClient createClient(String remoteHost, int remotePort) {
+  public TransportClient createClient(String remoteHost, int remotePort) throws IOException {
     // Get connection from the connection pool first.
     // If it is not found or not active, create a new one.
     final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort);
     TransportClient cachedClient = connectionPool.get(address);
     if (cachedClient != null) {
       if (cachedClient.isActive()) {
+        logger.trace("Returning cached connection to {}: {}", address, cachedClient);
         return cachedClient;
       } else {
+        logger.info("Found inactive connection to {}, closing it.", address);
         connectionPool.remove(address, cachedClient); // Remove inactive clients.
       }
     }
@@ -133,10 +134,10 @@ public class TransportClientFactory implements Closeable {
     long preConnect = System.currentTimeMillis();
     ChannelFuture cf = bootstrap.connect(address);
     if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) {
-      throw new RuntimeException(
+      throw new IOException(
         String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs()));
     } else if (cf.cause() != null) {
-      throw new RuntimeException(String.format("Failed to connect to %s", address), cf.cause());
+      throw new IOException(String.format("Failed to connect to %s", address), cf.cause());
     }
 
     TransportClient client = clientRef.get();
@@ -198,7 +199,7 @@ public class TransportClientFactory implements Closeable {
    */
   private PooledByteBufAllocator createPooledByteBufAllocator() {
     return new PooledByteBufAllocator(
-        PlatformDependent.directBufferPreferred(),
+        conf.preferDirectBufs() && PlatformDependent.directBufferPreferred(),
         getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"),
         getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"),
         getPrivateStaticField("DEFAULT_PAGE_SIZE"),

http://git-wip-us.apache.org/repos/asf/spark/blob/f165b2bb/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
index d896559..2044afb 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
@@ -17,6 +17,7 @@
 
 package org.apache.spark.network.client;
 
+import java.io.IOException;
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
 
@@ -94,7 +95,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
       String remoteAddress = NettyUtils.getRemoteAddress(channel);
       logger.error("Still have {} requests outstanding when connection from {} is closed",
         numOutstandingRequests(), remoteAddress);
-      failOutstandingRequests(new RuntimeException("Connection from " + remoteAddress + " closed"));
+      failOutstandingRequests(new IOException("Connection from " + remoteAddress + " closed"));
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f165b2bb/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
index 4cb8bec..91d1e8a 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
@@ -66,7 +66,7 @@ public final class MessageEncoder extends MessageToMessageEncoder<Message> {
     // All messages have the frame length, message type, and message itself.
     int headerLength = 8 + msgType.encodedLength() + in.encodedLength();
     long frameLength = headerLength + bodyLength;
-    ByteBuf header = ctx.alloc().buffer(headerLength);
+    ByteBuf header = ctx.alloc().heapBuffer(headerLength);
     header.writeLong(frameLength);
     msgType.encode(header);
     in.encode(header);

http://git-wip-us.apache.org/repos/asf/spark/blob/f165b2bb/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java
index 70da48c..579676c 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java
@@ -28,6 +28,7 @@ import io.netty.channel.ChannelInitializer;
 import io.netty.channel.ChannelOption;
 import io.netty.channel.EventLoopGroup;
 import io.netty.channel.socket.SocketChannel;
+import io.netty.util.internal.PlatformDependent;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -71,11 +72,14 @@ public class TransportServer implements Closeable {
       NettyUtils.createEventLoop(ioMode, conf.serverThreads(), "shuffle-server");
     EventLoopGroup workerGroup = bossGroup;
 
+    PooledByteBufAllocator allocator = new PooledByteBufAllocator(
+      conf.preferDirectBufs() && PlatformDependent.directBufferPreferred());
+
     bootstrap = new ServerBootstrap()
       .group(bossGroup, workerGroup)
       .channel(NettyUtils.getServerChannelClass(ioMode))
-      .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
-      .childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT);
+      .option(ChannelOption.ALLOCATOR, allocator)
+      .childOption(ChannelOption.ALLOCATOR, allocator);
 
     if (conf.backLog() > 0) {
       bootstrap.option(ChannelOption.SO_BACKLOG, conf.backLog());

http://git-wip-us.apache.org/repos/asf/spark/blob/f165b2bb/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java
index b187234..2a7664f 100644
--- a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java
@@ -37,13 +37,17 @@ import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
  * Utilities for creating various Netty constructs based on whether we're using EPOLL or NIO.
  */
 public class NettyUtils {
-  /** Creates a Netty EventLoopGroup based on the IOMode. */
-  public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) {
-
-    ThreadFactory threadFactory = new ThreadFactoryBuilder()
+  /** Creates a new ThreadFactory which prefixes each thread with the given name. */
+  public static ThreadFactory createThreadFactory(String threadPoolPrefix) {
+    return new ThreadFactoryBuilder()
       .setDaemon(true)
-      .setNameFormat(threadPrefix + "-%d")
+      .setNameFormat(threadPoolPrefix + "-%d")
       .build();
+  }
+
+  /** Creates a Netty EventLoopGroup based on the IOMode. */
+  public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) {
+    ThreadFactory threadFactory = createThreadFactory(threadPrefix);
 
     switch (mode) {
       case NIO:

http://git-wip-us.apache.org/repos/asf/spark/blob/f165b2bb/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
index 823790d..787a8f0 100644
--- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -30,6 +30,11 @@ public class TransportConf {
   /** IO mode: nio or epoll */
   public String ioMode() { return conf.get("spark.shuffle.io.mode", "NIO").toUpperCase(); }
 
+  /** If true (the default), we will prefer allocating off-heap byte buffers within Netty. */
+  public boolean preferDirectBufs() {
+    return conf.getBoolean("spark.shuffle.io.preferDirectBufs", true);
+  }
+
   /** Connect timeout in secs. Default 120 secs. */
   public int connectionTimeoutMs() {
     return conf.getInt("spark.shuffle.io.connectionTimeout", 120) * 1000;
@@ -58,4 +63,16 @@ public class TransportConf {
 
   /** Timeout for a single round trip of SASL token exchange, in milliseconds. */
   public int saslRTTimeout() { return conf.getInt("spark.shuffle.sasl.timeout", 30000); }
+
+  /**
+   * Max number of times we will retry a request that fails with an IOException (such as a
+   * connection timeout). Defaults to 3; if set to 0, we will not do any retries.
+   */
+  public int maxIORetries() { return conf.getInt("spark.shuffle.io.maxRetries", 3); }
+
+  /**
+   * Time (in milliseconds) that we will wait in order to perform a retry after an IOException.
+   * Defaults to 5000 ms. Only relevant if maxIORetries > 0.
+   */
+  public int ioRetryWaitTime() { return conf.getInt("spark.shuffle.io.retryWaitMs", 5000); }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/f165b2bb/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
----------------------------------------------------------------------
diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
index 5a10fdb..822bef1 100644
--- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
@@ -17,6 +17,7 @@
 
 package org.apache.spark.network;
 
+import java.io.IOException;
 import java.util.concurrent.TimeoutException;
 
 import org.junit.After;
@@ -57,7 +58,7 @@ public class TransportClientFactorySuite {
   }
 
   @Test
-  public void createAndReuseBlockClients() throws TimeoutException {
+  public void createAndReuseBlockClients() throws IOException {
     TransportClientFactory factory = context.createClientFactory();
     TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
     TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
@@ -70,7 +71,7 @@ public class TransportClientFactorySuite {
   }
 
   @Test
-  public void neverReturnInactiveClients() throws Exception {
+  public void neverReturnInactiveClients() throws IOException, InterruptedException {
     TransportClientFactory factory = context.createClientFactory();
     TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
     c1.close();
@@ -88,7 +89,7 @@ public class TransportClientFactorySuite {
   }
 
   @Test
-  public void closeBlockClientsWithFactory() throws TimeoutException {
+  public void closeBlockClientsWithFactory() throws IOException {
     TransportClientFactory factory = context.createClientFactory();
     TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
     TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort());

http://git-wip-us.apache.org/repos/asf/spark/blob/f165b2bb/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
----------------------------------------------------------------------
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
index 3aa95d0..27884b8 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
@@ -17,6 +17,7 @@
 
 package org.apache.spark.network.shuffle;
 
+import java.io.IOException;
 import java.util.List;
 
 import com.google.common.collect.Lists;
@@ -76,17 +77,33 @@ public class ExternalShuffleClient extends ShuffleClient {
 
   @Override
   public void fetchBlocks(
-      String host,
-      int port,
-      String execId,
+      final String host,
+      final int port,
+      final String execId,
       String[] blockIds,
       BlockFetchingListener listener) {
     assert appId != null : "Called before init()";
     logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId);
     try {
-      TransportClient client = clientFactory.createClient(host, port);
-      new OneForOneBlockFetcher(client, blockIds, listener)
-        .start(new ExternalShuffleMessages.OpenShuffleBlocks(appId, execId, blockIds));
+      RetryingBlockFetcher.BlockFetchStarter blockFetchStarter =
+        new RetryingBlockFetcher.BlockFetchStarter() {
+          @Override
+          public void createAndStart(String[] blockIds, BlockFetchingListener listener)
+              throws IOException {
+            TransportClient client = clientFactory.createClient(host, port);
+            new OneForOneBlockFetcher(client, blockIds, listener)
+              .start(new ExternalShuffleMessages.OpenShuffleBlocks(appId, execId, blockIds));
+          }
+        };
+
+      int maxRetries = conf.maxIORetries();
+      if (maxRetries > 0) {
+        // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's
+        // a bug in this code. We should remove the if statement once we're sure of the stability.
+        new RetryingBlockFetcher(conf, blockFetchStarter, blockIds, listener).start();
+      } else {
+        blockFetchStarter.createAndStart(blockIds, listener);
+      }
     } catch (Exception e) {
       logger.error("Exception while beginning fetchBlocks", e);
       for (String blockId : blockIds) {
@@ -108,7 +125,7 @@ public class ExternalShuffleClient extends ShuffleClient {
       String host,
       int port,
       String execId,
-      ExecutorShuffleInfo executorInfo) {
+      ExecutorShuffleInfo executorInfo) throws IOException {
     assert appId != null : "Called before init()";
     TransportClient client = clientFactory.createClient(host, port);
     byte[] registerExecutorMessage =

http://git-wip-us.apache.org/repos/asf/spark/blob/f165b2bb/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
----------------------------------------------------------------------
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
index 39b6f30..9e77a1f 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
@@ -51,9 +51,6 @@ public class OneForOneBlockFetcher {
       TransportClient client,
       String[] blockIds,
       BlockFetchingListener listener) {
-    if (blockIds.length == 0) {
-      throw new IllegalArgumentException("Zero-sized blockIds array");
-    }
     this.client = client;
     this.blockIds = blockIds;
     this.listener = listener;
@@ -82,6 +79,10 @@ public class OneForOneBlockFetcher {
    * {@link ShuffleStreamHandle}. We will send all fetch requests immediately, without throttling.
    */
   public void start(Object openBlocksMessage) {
+    if (blockIds.length == 0) {
+      throw new IllegalArgumentException("Zero-sized blockIds array");
+    }
+
     client.sendRpc(JavaUtils.serialize(openBlocksMessage), new RpcResponseCallback() {
       @Override
       public void onSuccess(byte[] response) {
@@ -95,7 +96,7 @@ public class OneForOneBlockFetcher {
             client.fetchChunk(streamHandle.streamId, i, chunkCallback);
           }
         } catch (Exception e) {
-          logger.error("Failed while starting block fetches", e);
+          logger.error("Failed while starting block fetches after success", e);
           failRemainingBlocks(blockIds, e);
         }
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/f165b2bb/network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java
----------------------------------------------------------------------
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java
new file mode 100644
index 0000000..f8a1a26
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java
@@ -0,0 +1,234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.LinkedHashSet;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.collect.Sets;
+import com.google.common.util.concurrent.Uninterruptibles;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.util.NettyUtils;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * Wraps another BlockFetcher with the ability to automatically retry fetches which fail due to
+ * IOExceptions, which we hope are due to transient network conditions.
+ *
+ * This fetcher provides stronger guarantees regarding the parent BlockFetchingListener. In
+ * particular, the listener will be invoked exactly once per blockId, with a success or failure.
+ */
+public class RetryingBlockFetcher {
+
+  /**
+   * Used to initiate the first fetch for all blocks, and subsequently for retrying the fetch on any
+   * remaining blocks.
+   */
+  public static interface BlockFetchStarter {
+    /**
+     * Creates a new BlockFetcher to fetch the given block ids which may do some synchronous
+     * bootstrapping followed by fully asynchronous block fetching.
+     * The BlockFetcher must eventually invoke the Listener on every input blockId, or else this
+     * method must throw an exception.
+     *
+     * This method should always attempt to get a new TransportClient from the
+     * {@link org.apache.spark.network.client.TransportClientFactory} in order to fix connection
+     * issues.
+     */
+    void createAndStart(String[] blockIds, BlockFetchingListener listener) throws IOException;
+  }
+
+  /** Shared executor service used for waiting and retrying. */
+  private static final ExecutorService executorService = Executors.newCachedThreadPool(
+    NettyUtils.createThreadFactory("Block Fetch Retry"));
+
+  private final Logger logger = LoggerFactory.getLogger(RetryingBlockFetcher.class);
+
+  /** Used to initiate new Block Fetches on our remaining blocks. */
+  private final BlockFetchStarter fetchStarter;
+
+  /** Parent listener which we delegate all successful or permanently failed block fetches to. */
+  private final BlockFetchingListener listener;
+
+  /** Max number of times we are allowed to retry. */
+  private final int maxRetries;
+
+  /** Milliseconds to wait before each retry. */
+  private final int retryWaitTime;
+
+  // NOTE:
+  // All of our non-final fields, as well as the mutable outstandingBlocksIds set, are guarded by
+  // 'this' and should only be accessed/mutated while inside a synchronized block.
+  /** Number of times we've attempted to retry so far. */
+  private int retryCount = 0;
+
+  /**
+   * Set of all block ids which have not been fetched successfully or with a non-IO Exception.
+   * A retry involves requesting every outstanding block. Note that since this is a LinkedHashSet,
+   * input ordering is preserved, so we always request blocks in the same order the user provided.
+   */
+  private final LinkedHashSet<String> outstandingBlocksIds;
+
+  /**
+   * The BlockFetchingListener that is active with our current BlockFetcher.
+   * When we start a retry, we immediately replace this with a new Listener, which causes any
+   * old Listeners to ignore all further responses.
+   */
+  private RetryingBlockFetchListener currentListener;
+
+  public RetryingBlockFetcher(
+      TransportConf conf,
+      BlockFetchStarter fetchStarter,
+      String[] blockIds,
+      BlockFetchingListener listener) {
+    this.fetchStarter = fetchStarter;
+    this.listener = listener;
+    this.maxRetries = conf.maxIORetries();
+    this.retryWaitTime = conf.ioRetryWaitTime();
+    this.outstandingBlocksIds = Sets.newLinkedHashSet();
+    Collections.addAll(outstandingBlocksIds, blockIds);
+    this.currentListener = new RetryingBlockFetchListener();
+  }
+
+  /**
+   * Initiates the fetch of all blocks provided in the constructor, with possible retries in the
+   * event of transient IOExceptions.
+   */
+  public void start() {
+    fetchAllOutstanding();
+  }
+
+  /**
+   * Fires off a request to fetch all blocks that have not been fetched successfully or permanently
+   * failed (i.e., by a non-IOException).
+   *
+   * Shared state is read under the lock on 'this', but createAndStart() is invoked without holding
+   * it. If starting the fetch throws, the exception either schedules another retry or, when no
+   * retry is allowed, is reported to the parent listener for every block in this attempt.
+   */
+  private void fetchAllOutstanding() {
+    // Start by retrieving our shared state within a synchronized block.
+    String[] blockIdsToFetch;
+    int numRetries;
+    RetryingBlockFetchListener myListener;
+    synchronized (this) {
+      blockIdsToFetch = outstandingBlocksIds.toArray(new String[outstandingBlocksIds.size()]);
+      numRetries = retryCount;
+      // Responses for this attempt go through the listener captured here. A later retry replaces
+      // 'currentListener', which turns this listener's callbacks into no-ops.
+      myListener = currentListener;
+    }
+
+    // Now initiate the fetch on all outstanding blocks, possibly initiating a retry if that fails.
+    try {
+      fetchStarter.createAndStart(blockIdsToFetch, myListener);
+    } catch (Exception e) {
+      logger.error(String.format("Exception while beginning fetch of %s outstanding blocks %s",
+        blockIdsToFetch.length, numRetries > 0 ? "(after " + numRetries + " retries)" : ""), e);
+
+      // NOTE(review): this path does not check whether myListener is still current or whether the
+      // blocks are still outstanding. Suppose createAndStart() delivered callbacks before throwing,
+      // possibly starting a retry. This branch could then start a second retry, or report failure
+      // for a block that already succeeded. Confirm createAndStart() never does that.
+      if (shouldRetry(e)) {
+        initiateRetry();
+      } else {
+        // Permanent failure: notify the parent listener directly. These ids are not removed from
+        // outstandingBlocksIds, but no further attempt is scheduled for them here.
+        for (String bid : blockIdsToFetch) {
+          listener.onBlockFetchFailure(bid, e);
+        }
+      }
+    }
+  }
+
+  /**
+   * Schedules another attempt at the outstanding blocks without doing any waiting itself. Under
+   * the lock on 'this' it bumps the retry counter, installs a fresh listener, and hands off a task
+   * that waits the configured time and then calls fetchAllOutstanding().
+   */
+  private synchronized void initiateRetry() {
+    retryCount++;
+    // A new listener makes every callback still in flight from the previous attempt a no-op.
+    currentListener = new RetryingBlockFetchListener();
+
+    logger.info("Retrying fetch ({}/{}) for {} outstanding blocks after {} ms",
+      retryCount, maxRetries, outstandingBlocksIds.size(), retryWaitTime);
+
+    Runnable delayedFetch = new Runnable() {
+      @Override
+      public void run() {
+        // Back off before trying again; interrupts do not cut this wait short.
+        Uninterruptibles.sleepUninterruptibly(retryWaitTime, TimeUnit.MILLISECONDS);
+        fetchAllOutstanding();
+      }
+    };
+    executorService.submit(delayedFetch);
+  }
+
+  /**
+   * Returns true if we should retry due to a block fetch failure. We will retry if and only if
+   * the exception is an IOException (or its immediate cause is one) and we have not already
+   * retried 'maxRetries' times.
+   *
+   * Only the immediate cause is inspected; deeper causes in the chain are not considered.
+   */
+  private synchronized boolean shouldRetry(Throwable e) {
+    // 'instanceof' evaluates to false for null, so a missing cause needs no separate check.
+    boolean isIOException = e instanceof IOException || e.getCause() instanceof IOException;
+    boolean hasRemainingRetries = retryCount < maxRetries;
+    return isIOException && hasRemainingRetries;
+  }
+
+  /**
+   * This listener intercepts block fetch responses and forwards them to our parent listener.
+   * Note that in the event of a retry, we will immediately replace the 'currentListener' field,
+   * indicating that any responses from non-current Listeners should be ignored.
+   *
+   * Bookkeeping happens while holding the lock on the enclosing RetryingBlockFetcher. The parent
+   * listener is only invoked after that lock has been released.
+   */
+  private class RetryingBlockFetchListener implements BlockFetchingListener {
+    @Override
+    public void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
+      // We will only forward this success message to our parent listener if this block request is
+      // outstanding and we are still the active listener.
+      boolean shouldForwardSuccess = false;
+      synchronized (RetryingBlockFetcher.this) {
+        if (this == currentListener && outstandingBlocksIds.contains(blockId)) {
+          // Removing the id ensures later attempts do not request this block again.
+          outstandingBlocksIds.remove(blockId);
+          shouldForwardSuccess = true;
+        }
+      }
+
+      // Now actually invoke the parent listener, outside of the synchronized block.
+      if (shouldForwardSuccess) {
+        listener.onBlockFetchSuccess(blockId, data);
+      }
+    }
+
+    @Override
+    public void onBlockFetchFailure(String blockId, Throwable exception) {
+      // We will only forward this failure to our parent listener if this block request is
+      // outstanding, we are still the active listener, AND we cannot retry the fetch.
+      boolean shouldForwardFailure = false;
+      synchronized (RetryingBlockFetcher.this) {
+        if (this == currentListener && outstandingBlocksIds.contains(blockId)) {
+          if (shouldRetry(exception)) {
+            // The block stays in outstandingBlocksIds so the retry requests it again. initiateRetry()
+            // swaps in a new listener, so any further callbacks to this one are ignored. That
+            // includes other IOExceptions from the same attempt, so they do not retry again.
+            initiateRetry();
+          } else {
+            logger.error(String.format("Failed to fetch block %s, and will not retry (%s retries)",
+              blockId, retryCount), exception);
+            outstandingBlocksIds.remove(blockId);
+            shouldForwardFailure = true;
+          }
+        }
+      }
+
+      // Now actually invoke the parent listener, outside of the synchronized block.
+      if (shouldForwardFailure) {
+        listener.onBlockFetchFailure(blockId, exception);
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f165b2bb/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
----------------------------------------------------------------------
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
index 8478120..d25283e 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
@@ -93,7 +93,7 @@ public class SaslIntegrationSuite {
   }
 
   @Test
-  public void testGoodClient() {
+  public void testGoodClient() throws IOException {
     clientFactory = context.createClientFactory(
       Lists.<TransportClientBootstrap>newArrayList(
         new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("good-key"))));
@@ -119,7 +119,7 @@ public class SaslIntegrationSuite {
   }
 
   @Test
-  public void testNoSaslClient() {
+  public void testNoSaslClient() throws IOException {
     clientFactory = context.createClientFactory(
       Lists.<TransportClientBootstrap>newArrayList());
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f165b2bb/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
----------------------------------------------------------------------
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
index 71e017b..06294fe 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
@@ -259,14 +259,20 @@ public class ExternalShuffleIntegrationSuite {
 
   @Test
   public void testFetchNoServer() throws Exception {
-    registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
-    FetchResult execFetch = fetchBlocks("exec-0",
-      new String[] { "shuffle_1_0_0", "shuffle_1_0_1" }, 1 /* port */);
-    assertTrue(execFetch.successBlocks.isEmpty());
-    assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks);
+    // Fetching from port 1 is expected to fail for every block. Retries are disabled so the
+    // failure is reported without waiting out retry back-offs. The property's previous value, if
+    // any, is restored afterwards so other tests in this JVM are unaffected.
+    String previousMaxRetries = System.setProperty("spark.shuffle.io.maxRetries", "0");
+    try {
+      registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
+      FetchResult execFetch = fetchBlocks("exec-0",
+        new String[]{"shuffle_1_0_0", "shuffle_1_0_1"}, 1 /* port */);
+      assertTrue(execFetch.successBlocks.isEmpty());
+      assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks);
+    } finally {
+      if (previousMaxRetries == null) {
+        System.clearProperty("spark.shuffle.io.maxRetries");
+      } else {
+        System.setProperty("spark.shuffle.io.maxRetries", previousMaxRetries);
+      }
+    }
   }
 
-  private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) {
+  private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo)
+      throws IOException {
     ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false);
     client.init(APP_ID);
     client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(),

http://git-wip-us.apache.org/repos/asf/spark/blob/f165b2bb/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
----------------------------------------------------------------------
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
index 4c18fcd..848c88f 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
@@ -17,6 +17,8 @@
 
 package org.apache.spark.network.shuffle;
 
+import java.io.IOException;
+
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
@@ -54,7 +56,7 @@ public class ExternalShuffleSecuritySuite {
   }
 
+  /**
+   * Registers with app id "my-app-id" and secret "secret". The test passes only if validate()
+   * completes without throwing. NOTE(review): presumably this pair matches the secret the server
+   * fixture was configured with; that setup is outside this hunk, so confirm.
+   */
   @Test
-  public void testValid() {
+  public void testValid() throws IOException {
     validate("my-app-id", "secret");
   }
 
@@ -77,7 +79,7 @@ public class ExternalShuffleSecuritySuite {
   }
 
   /** Creates an ExternalShuffleClient and attempts to register with the server. */
-  private void validate(String appId, String secretKey) {
+  private void validate(String appId, String secretKey) throws IOException {
     ExternalShuffleClient client =
       new ExternalShuffleClient(conf, new TestSecretKeyHolder(appId, secretKey), true);
     client.init(appId);

http://git-wip-us.apache.org/repos/asf/spark/blob/f165b2bb/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java
----------------------------------------------------------------------
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java
new file mode 100644
index 0000000..0191fe5
--- /dev/null
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java
@@ -0,0 +1,310 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.LinkedHashSet;
+import java.util.Map;
+
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Sets;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+import org.mockito.stubbing.Stubber;
+
+import static org.junit.Assert.*;
+import static org.mockito.Mockito.*;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+import static org.apache.spark.network.shuffle.RetryingBlockFetcher.BlockFetchStarter;
+
+/**
+ * Tests retry logic by throwing IOExceptions and ensuring that subsequent attempts are made to
+ * fetch the lost blocks.
+ */
+public class RetryingBlockFetcherSuite {
+
+  ManagedBuffer block0 = new NioManagedBuffer(ByteBuffer.wrap(new byte[13]));
+  ManagedBuffer block1 = new NioManagedBuffer(ByteBuffer.wrap(new byte[7]));
+  ManagedBuffer block2 = new NioManagedBuffer(ByteBuffer.wrap(new byte[19]));
+
+  @Before
+  public void beforeEach() {
+    System.setProperty("spark.shuffle.io.maxRetries", "2");
+    System.setProperty("spark.shuffle.io.retryWaitMs", "0");
+  }
+
+  @After
+  public void afterEach() {
+    System.clearProperty("spark.shuffle.io.maxRetries");
+    System.clearProperty("spark.shuffle.io.retryWaitMs");
+  }
+
+  @Test
+  public void testNoFailures() throws IOException {
+    BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+    Map[] interactions = new Map[] {
+      // Immediately return both blocks successfully.
+      ImmutableMap.<String, Object>builder()
+        .put("b0", block0)
+        .put("b1", block1)
+        .build(),
+    };
+
+    performInteractions(interactions, listener);
+
+    verify(listener).onBlockFetchSuccess("b0", block0);
+    verify(listener).onBlockFetchSuccess("b1", block1);
+    verifyNoMoreInteractions(listener);
+  }
+
+  @Test
+  public void testUnrecoverableFailure() throws IOException {
+    BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+    Map[] interactions = new Map[] {
+      // b0 throws a non-IOException error, so it will be failed without retry.
+      ImmutableMap.<String, Object>builder()
+        .put("b0", new RuntimeException("Ouch!"))
+        .put("b1", block1)
+        .build(),
+    };
+
+    performInteractions(interactions, listener);
+
+    verify(listener).onBlockFetchFailure(eq("b0"), (Throwable) any());
+    verify(listener).onBlockFetchSuccess("b1", block1);
+    verifyNoMoreInteractions(listener);
+  }
+
+  @Test
+  public void testSingleIOExceptionOnFirst() throws IOException {
+    BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+    Map[] interactions = new Map[] {
+      // IOException will cause a retry. Since b0 fails, we will retry both.
+      ImmutableMap.<String, Object>builder()
+        .put("b0", new IOException("Connection failed or something"))
+        .put("b1", block1)
+        .build(),
+      ImmutableMap.<String, Object>builder()
+        .put("b0", block0)
+        .put("b1", block1)
+        .build(),
+    };
+
+    performInteractions(interactions, listener);
+
+    verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0);
+    verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1);
+    verifyNoMoreInteractions(listener);
+  }
+
+  @Test
+  public void testSingleIOExceptionOnSecond() throws IOException {
+    BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+    Map[] interactions = new Map[] {
+      // IOException will cause a retry. Since b1 fails, we will not retry b0.
+      ImmutableMap.<String, Object>builder()
+        .put("b0", block0)
+        .put("b1", new IOException("Connection failed or something"))
+        .build(),
+      ImmutableMap.<String, Object>builder()
+        .put("b1", block1)
+        .build(),
+    };
+
+    performInteractions(interactions, listener);
+
+    verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0);
+    verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1);
+    verifyNoMoreInteractions(listener);
+  }
+
+  @Test
+  public void testTwoIOExceptions() throws IOException {
+    BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+    Map[] interactions = new Map[] {
+      // b0's IOException will trigger retry, b1's will be ignored.
+      ImmutableMap.<String, Object>builder()
+        .put("b0", new IOException())
+        .put("b1", new IOException())
+        .build(),
+      // Next, b0 is successful and b1 errors again, so we just request that one.
+      ImmutableMap.<String, Object>builder()
+        .put("b0", block0)
+        .put("b1", new IOException())
+        .build(),
+      // b1 returns successfully within 2 retries.
+      ImmutableMap.<String, Object>builder()
+        .put("b1", block1)
+        .build(),
+    };
+
+    performInteractions(interactions, listener);
+
+    verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0);
+    verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1);
+    verifyNoMoreInteractions(listener);
+  }
+
+  @Test
+  public void testThreeIOExceptions() throws IOException {
+    BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+    Map[] interactions = new Map[] {
+      // b0's IOException will trigger retry, b1's will be ignored.
+      ImmutableMap.<String, Object>builder()
+        .put("b0", new IOException())
+        .put("b1", new IOException())
+        .build(),
+      // Next, b0 is successful and b1 errors again, so we just request that one.
+      ImmutableMap.<String, Object>builder()
+        .put("b0", block0)
+        .put("b1", new IOException())
+        .build(),
+      // b1 errors again, but this was the last retry
+      ImmutableMap.<String, Object>builder()
+        .put("b1", new IOException())
+        .build(),
+      // This is not reached -- b1 has failed.
+      ImmutableMap.<String, Object>builder()
+        .put("b1", block1)
+        .build(),
+    };
+
+    performInteractions(interactions, listener);
+
+    verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0);
+    verify(listener, timeout(5000)).onBlockFetchFailure(eq("b1"), (Throwable) any());
+    verifyNoMoreInteractions(listener);
+  }
+
+  @Test
+  public void testRetryAndUnrecoverable() throws IOException {
+    BlockFetchingListener listener = mock(BlockFetchingListener.class);
+
+    Map[] interactions = new Map[] {
+      // b0's IOException will trigger retry, subsequent messages will be ignored.
+      ImmutableMap.<String, Object>builder()
+        .put("b0", new IOException())
+        .put("b1", new RuntimeException())
+        .put("b2", block2)
+        .build(),
+      // Next, b0 is successful, b1 errors unrecoverably, and b2 triggers a retry.
+      ImmutableMap.<String, Object>builder()
+        .put("b0", block0)
+        .put("b1", new RuntimeException())
+        .put("b2", new IOException())
+        .build(),
+      // b2 succeeds in its last retry.
+      ImmutableMap.<String, Object>builder()
+        .put("b2", block2)
+        .build(),
+    };
+
+    performInteractions(interactions, listener);
+
+    verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0);
+    verify(listener, timeout(5000)).onBlockFetchFailure(eq("b1"), (Throwable) any());
+    verify(listener, timeout(5000)).onBlockFetchSuccess("b2", block2);
+    verifyNoMoreInteractions(listener);
+  }
+
+  /**
+   * Performs a set of interactions in response to block requests from a RetryingBlockFetcher.
+   * Each interaction is a Map from BlockId to either ManagedBuffer or Exception. This interaction
+   * means "respond to the next block fetch request with these Successful buffers and these Failure
+   * exceptions". We verify that the expected block ids are exactly the ones requested.
+   *
+   * If multiple interactions are supplied, they will be used in order. This is useful for encoding
+   * retries -- the first interaction may include an IOException, which causes a retry of some
+   * subset of the original blocks in a second interaction.
+   */
+  @SuppressWarnings("unchecked")
+  private void performInteractions(final Map[] interactions, BlockFetchingListener listener)
+    throws IOException {
+
+    TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
+    BlockFetchStarter fetchStarter = mock(BlockFetchStarter.class);
+
+    Stubber stub = null;
+
+    // Contains all blockIds that are referenced across all interactions.
+    final LinkedHashSet<String> blockIds = Sets.newLinkedHashSet();
+
+    for (final Map<String, Object> interaction : interactions) {
+      blockIds.addAll(interaction.keySet());
+
+      Answer<Void> answer = new Answer<Void>() {
+        @Override
+        public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
+          try {
+            // Verify that the RetryingBlockFetcher requested the expected blocks.
+            String[] requestedBlockIds = (String[]) invocationOnMock.getArguments()[0];
+            String[] desiredBlockIds = interaction.keySet().toArray(new String[interaction.size()]);
+            assertArrayEquals(desiredBlockIds, requestedBlockIds);
+
+            // Now actually invoke the success/failure callbacks on each block.
+            BlockFetchingListener retryListener =
+              (BlockFetchingListener) invocationOnMock.getArguments()[1];
+            for (Map.Entry<String, Object> block : interaction.entrySet()) {
+              String blockId = block.getKey();
+              Object blockValue = block.getValue();
+
+              if (blockValue instanceof ManagedBuffer) {
+                retryListener.onBlockFetchSuccess(blockId, (ManagedBuffer) blockValue);
+              } else if (blockValue instanceof Exception) {
+                retryListener.onBlockFetchFailure(blockId, (Exception) blockValue);
+              } else {
+                fail("Can only handle ManagedBuffers and Exceptions, got " + blockValue);
+              }
+            }
+            return null;
+          } catch (Throwable e) {
+            e.printStackTrace();
+            throw e;
+          }
+        }
+      };
+
+      // This is either the first stub, or should be chained behind the prior ones.
+      if (stub == null) {
+        stub = doAnswer(answer);
+      } else {
+        stub.doAnswer(answer);
+      }
+    }
+
+    assert stub != null;
+    stub.when(fetchStarter).createAndStart((String[]) any(), (BlockFetchingListener) anyObject());
+    String[] blockIdArray = blockIds.toArray(new String[blockIds.size()]);
+    new RetryingBlockFetcher(conf, fetchStarter, blockIdArray, listener).start();
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message