spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From a...@apache.org
Subject spark git commit: [SPARK-7003] Improve reliability of connection failure detection between Netty block transfer service endpoints
Date Mon, 20 Apr 2015 16:54:33 GMT
Repository: spark
Updated Branches:
  refs/heads/master 1be207078 -> 968ad9721


[SPARK-7003] Improve reliability of connection failure detection between Netty block transfer
service endpoints

Currently we rely on the assumption that an exception will be raised and the channel closed
if two endpoints cannot communicate over a Netty TCP channel. However, this guarantee does
not hold in all network environments, and [SPARK-6962](https://issues.apache.org/jira/browse/SPARK-6962)
seems to point to a case where only the server side of the connection detected a fault.

This patch improves the robustness of fetch/RPC requests by adding an explicit timeout in the
transport layer, which closes the connection if the channel stays inactive for the timeout period
while there are outstanding requests.

NB: Excluding the testing-related code, this patch adds only around 50 lines.

Author: Aaron Davidson <aaron@databricks.com>

Closes #5584 from aarondav/timeout and squashes the following commits:

8699680 [Aaron Davidson] Address Reynold's comments
37ce656 [Aaron Davidson] [SPARK-7003] Improve reliability of connection failure detection
between Netty block transfer service endpoints


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/968ad972
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/968ad972
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/968ad972

Branch: refs/heads/master
Commit: 968ad972175390bb0a96918fd3c7b318d70fa466
Parents: 1be2070
Author: Aaron Davidson <aaron@databricks.com>
Authored: Mon Apr 20 09:54:21 2015 -0700
Committer: Aaron Davidson <aaron@databricks.com>
Committed: Mon Apr 20 09:54:21 2015 -0700

----------------------------------------------------------------------
 .../apache/spark/network/TransportContext.java  |   5 +-
 .../client/TransportResponseHandler.java        |  14 +-
 .../network/server/TransportChannelHandler.java |  33 ++-
 .../spark/network/util/MapConfigProvider.java   |  41 +++
 .../apache/spark/network/util/NettyUtils.java   |   2 +-
 .../network/RequestTimeoutIntegrationSuite.java | 277 +++++++++++++++++++
 .../network/TransportClientFactorySuite.java    |  21 +-
 7 files changed, 375 insertions(+), 18 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/968ad972/network/common/src/main/java/org/apache/spark/network/TransportContext.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
index f0a89c9..3fe69b1 100644
--- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java
+++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
@@ -22,6 +22,7 @@ import java.util.List;
 import com.google.common.collect.Lists;
 import io.netty.channel.Channel;
 import io.netty.channel.socket.SocketChannel;
+import io.netty.handler.timeout.IdleStateHandler;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -106,6 +107,7 @@ public class TransportContext {
         .addLast("encoder", encoder)
         .addLast("frameDecoder", NettyUtils.createFrameDecoder())
         .addLast("decoder", decoder)
+        .addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs()
/ 1000))
         // NOTE: Chunks are currently guaranteed to be returned in the order of request,
but this
         // would require more logic to guarantee if this were not part of the same event
loop.
         .addLast("handler", channelHandler);
@@ -126,7 +128,8 @@ public class TransportContext {
     TransportClient client = new TransportClient(channel, responseHandler);
     TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client,
       rpcHandler);
-    return new TransportChannelHandler(client, responseHandler, requestHandler);
+    return new TransportChannelHandler(client, responseHandler, requestHandler,
+      conf.connectionTimeoutMs());
   }
 
   public TransportConf getConf() { return conf; }

http://git-wip-us.apache.org/repos/asf/spark/blob/968ad972/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
index 2044afb..94fc21a 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
@@ -20,8 +20,8 @@ package org.apache.spark.network.client;
 import java.io.IOException;
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicLong;
 
-import com.google.common.annotations.VisibleForTesting;
 import io.netty.channel.Channel;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -50,13 +50,18 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage>
{
 
   private final Map<Long, RpcResponseCallback> outstandingRpcs;
 
+  /** Records the time (in system nanoseconds) that the last fetch or RPC request was sent.
*/
+  private final AtomicLong timeOfLastRequestNs;
+
   public TransportResponseHandler(Channel channel) {
     this.channel = channel;
     this.outstandingFetches = new ConcurrentHashMap<StreamChunkId, ChunkReceivedCallback>();
     this.outstandingRpcs = new ConcurrentHashMap<Long, RpcResponseCallback>();
+    this.timeOfLastRequestNs = new AtomicLong(0);
   }
 
   public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback)
{
+    timeOfLastRequestNs.set(System.nanoTime());
     outstandingFetches.put(streamChunkId, callback);
   }
 
@@ -65,6 +70,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage>
{
   }
 
   public void addRpcRequest(long requestId, RpcResponseCallback callback) {
+    timeOfLastRequestNs.set(System.nanoTime());
     outstandingRpcs.put(requestId, callback);
   }
 
@@ -161,8 +167,12 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage>
{
   }
 
   /** Returns total number of outstanding requests (fetch requests + rpcs) */
-  @VisibleForTesting
   public int numOutstandingRequests() {
     return outstandingFetches.size() + outstandingRpcs.size();
   }
+
+  /** Returns the time in nanoseconds of when the last request was sent out. */
+  public long getTimeOfLastRequestNs() {
+    return timeOfLastRequestNs.get();
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/968ad972/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
index e491367..8e0ee70 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
@@ -19,6 +19,8 @@ package org.apache.spark.network.server;
 
 import io.netty.channel.ChannelHandlerContext;
 import io.netty.channel.SimpleChannelInboundHandler;
+import io.netty.handler.timeout.IdleState;
+import io.netty.handler.timeout.IdleStateEvent;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -40,6 +42,11 @@ import org.apache.spark.network.util.NettyUtils;
  * Client.
  * This means that the Client also needs a RequestHandler and the Server needs a ResponseHandler,
  * for the Client's responses to the Server's requests.
+ *
+ * This class also handles timeouts from a {@link io.netty.handler.timeout.IdleStateHandler}.
+ * We consider a connection timed out if there are outstanding fetch or RPC requests but
no traffic
+ * on the channel for at least `requestTimeoutMs`. Note that this is duplex traffic; we will
not
+ * timeout if the client is continuously sending but getting no responses, for simplicity.
  */
 public class TransportChannelHandler extends SimpleChannelInboundHandler<Message> {
   private final Logger logger = LoggerFactory.getLogger(TransportChannelHandler.class);
@@ -47,14 +54,17 @@ public class TransportChannelHandler extends SimpleChannelInboundHandler<Message
   private final TransportClient client;
   private final TransportResponseHandler responseHandler;
   private final TransportRequestHandler requestHandler;
+  private final long requestTimeoutNs;
 
   public TransportChannelHandler(
       TransportClient client,
       TransportResponseHandler responseHandler,
-      TransportRequestHandler requestHandler) {
+      TransportRequestHandler requestHandler,
+      long requestTimeoutMs) {
     this.client = client;
     this.responseHandler = responseHandler;
     this.requestHandler = requestHandler;
+    this.requestTimeoutNs = requestTimeoutMs * 1000L * 1000;
   }
 
   public TransportClient getClient() {
@@ -93,4 +103,25 @@ public class TransportChannelHandler extends SimpleChannelInboundHandler<Message
       responseHandler.handle((ResponseMessage) request);
     }
   }
+
+  /** Triggered based on events from an {@link io.netty.handler.timeout.IdleStateHandler}.
*/
+  @Override
+  public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception
{
+    if (evt instanceof IdleStateEvent) {
+      IdleStateEvent e = (IdleStateEvent) evt;
+      // See class comment for timeout semantics. In addition to ensuring we only timeout
while
+      // there are outstanding requests, we also do a secondary consistency check to ensure
+      // there's no race between the idle timeout and incrementing the numOutstandingRequests.
+      boolean hasInFlightRequests = responseHandler.numOutstandingRequests() > 0;
+      boolean isActuallyOverdue =
+        System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > requestTimeoutNs;
+      if (e.state() == IdleState.ALL_IDLE && hasInFlightRequests && isActuallyOverdue)
{
+        String address = NettyUtils.getRemoteAddress(ctx.channel());
+        logger.error("Connection to {} has been quiet for {} ms while there are outstanding
" +
+          "requests. Assuming connection is dead; please adjust spark.network.timeout if
this " +
+          "is wrong.", address, requestTimeoutNs / 1000 / 1000);
+        ctx.close();
+      }
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/968ad972/network/common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java
b/network/common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java
new file mode 100644
index 0000000..668d235
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import com.google.common.collect.Maps;
+
+import java.util.Map;
+import java.util.NoSuchElementException;
+
+/** ConfigProvider based on a Map (copied in the constructor). */
+public class MapConfigProvider extends ConfigProvider {
+  private final Map<String, String> config;
+
+  public MapConfigProvider(Map<String, String> config) {
+    this.config = Maps.newHashMap(config);
+  }
+
+  @Override
+  public String get(String name) {
+    String value = config.get(name);
+    if (value == null) {
+      throw new NoSuchElementException(name);
+    }
+    return value;
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/968ad972/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java
index dabd626..26c6399 100644
--- a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java
@@ -98,7 +98,7 @@ public class NettyUtils {
     return new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 8, -8, 8);
   }
 
-  /** Returns the remote address on the channel or "&lt;remote address&gt;" if none
exists. */
+  /** Returns the remote address on the channel or "&lt;unknown remote&gt;" if none
exists. */
   public static String getRemoteAddress(Channel channel) {
     if (channel != null && channel.remoteAddress() != null) {
       return channel.remoteAddress().toString();

http://git-wip-us.apache.org/repos/asf/spark/blob/968ad972/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
----------------------------------------------------------------------
diff --git a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
new file mode 100644
index 0000000..84ebb33
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
@@ -0,0 +1,277 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import com.google.common.collect.Maps;
+import com.google.common.util.concurrent.Uninterruptibles;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.client.ChunkReceivedCallback;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.util.MapConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+import org.junit.*;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.*;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * Suite which ensures that requests that go without a response for the network timeout period
are
+ * failed, and the connection closed.
+ *
+ * In this suite, we use 2 seconds as the connection timeout, with some slack given in the
tests,
+ * to ensure stability in different test environments.
+ */
+public class RequestTimeoutIntegrationSuite {
+
+  private TransportServer server;
+  private TransportClientFactory clientFactory;
+
+  private StreamManager defaultManager;
+  private TransportConf conf;
+
+  // A large timeout that "shouldn't happen", for the sake of faulty tests not hanging forever.
+  private final int FOREVER = 60 * 1000;
+
+  @Before
+  public void setUp() throws Exception {
+    Map<String, String> configMap = Maps.newHashMap();
+    configMap.put("spark.shuffle.io.connectionTimeout", "2s");
+    conf = new TransportConf(new MapConfigProvider(configMap));
+
+    defaultManager = new StreamManager() {
+      @Override
+      public ManagedBuffer getChunk(long streamId, int chunkIndex) {
+        throw new UnsupportedOperationException();
+      }
+    };
+  }
+
+  @After
+  public void tearDown() {
+    if (server != null) {
+      server.close();
+    }
+    if (clientFactory != null) {
+      clientFactory.close();
+    }
+  }
+
+  // Basic suite: First request completes quickly, and second waits for longer than network
timeout.
+  @Test
+  public void timeoutInactiveRequests() throws Exception {
+    final Semaphore semaphore = new Semaphore(1);
+    final byte[] response = new byte[16];
+    RpcHandler handler = new RpcHandler() {
+      @Override
+      public void receive(TransportClient client, byte[] message, RpcResponseCallback callback)
{
+        try {
+          semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS);
+          callback.onSuccess(response);
+        } catch (InterruptedException e) {
+          // do nothing
+        }
+      }
+
+      @Override
+      public StreamManager getStreamManager() {
+        return defaultManager;
+      }
+    };
+
+    TransportContext context = new TransportContext(conf, handler);
+    server = context.createServer();
+    clientFactory = context.createClientFactory();
+    TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+
+    // First completes quickly (semaphore starts at 1).
+    TestCallback callback0 = new TestCallback();
+    synchronized (callback0) {
+      client.sendRpc(new byte[0], callback0);
+      callback0.wait(FOREVER);
+      assert (callback0.success.length == response.length);
+    }
+
+    // Second times out after 2 seconds, with slack. Must be IOException.
+    TestCallback callback1 = new TestCallback();
+    synchronized (callback1) {
+      client.sendRpc(new byte[0], callback1);
+      callback1.wait(4 * 1000);
+      assert (callback1.failure != null);
+      assert (callback1.failure instanceof IOException);
+    }
+    semaphore.release();
+  }
+
+  // A timeout will cause the connection to be closed, invalidating the current TransportClient.
+  // It should be the case that requesting a client from the factory produces a new, valid
one.
+  @Test
+  public void timeoutCleanlyClosesClient() throws Exception {
+    final Semaphore semaphore = new Semaphore(0);
+    final byte[] response = new byte[16];
+    RpcHandler handler = new RpcHandler() {
+      @Override
+      public void receive(TransportClient client, byte[] message, RpcResponseCallback callback)
{
+        try {
+          semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS);
+          callback.onSuccess(response);
+        } catch (InterruptedException e) {
+          // do nothing
+        }
+      }
+
+      @Override
+      public StreamManager getStreamManager() {
+        return defaultManager;
+      }
+    };
+
+    TransportContext context = new TransportContext(conf, handler);
+    server = context.createServer();
+    clientFactory = context.createClientFactory();
+
+    // First request should eventually fail.
+    TransportClient client0 =
+      clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+    TestCallback callback0 = new TestCallback();
+    synchronized (callback0) {
+      client0.sendRpc(new byte[0], callback0);
+      callback0.wait(FOREVER);
+      assert (callback0.failure instanceof IOException);
+      assert (!client0.isActive());
+    }
+
+    // Increment the semaphore and the second request should succeed quickly.
+    semaphore.release(2);
+    TransportClient client1 =
+      clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+    TestCallback callback1 = new TestCallback();
+    synchronized (callback1) {
+      client1.sendRpc(new byte[0], callback1);
+      callback1.wait(FOREVER);
+      assert (callback1.success.length == response.length);
+      assert (callback1.failure == null);
+    }
+  }
+
+  // The timeout is relative to the LAST request sent, which is kinda weird, but still.
+  // This test also makes sure the timeout works for Fetch requests as well as RPCs.
+  @Test
+  public void furtherRequestsDelay() throws Exception {
+    final byte[] response = new byte[16];
+    final StreamManager manager = new StreamManager() {
+      @Override
+      public ManagedBuffer getChunk(long streamId, int chunkIndex) {
+        Uninterruptibles.sleepUninterruptibly(FOREVER, TimeUnit.MILLISECONDS);
+        return new NioManagedBuffer(ByteBuffer.wrap(response));
+      }
+    };
+    RpcHandler handler = new RpcHandler() {
+      @Override
+      public void receive(TransportClient client, byte[] message, RpcResponseCallback callback)
{
+        throw new UnsupportedOperationException();
+      }
+
+      @Override
+      public StreamManager getStreamManager() {
+        return manager;
+      }
+    };
+
+    TransportContext context = new TransportContext(conf, handler);
+    server = context.createServer();
+    clientFactory = context.createClientFactory();
+    TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+
+    // Send one request, which will eventually fail.
+    TestCallback callback0 = new TestCallback();
+    client.fetchChunk(0, 0, callback0);
+    Uninterruptibles.sleepUninterruptibly(1200, TimeUnit.MILLISECONDS);
+
+    // Send a second request before the first has failed.
+    TestCallback callback1 = new TestCallback();
+    client.fetchChunk(0, 1, callback1);
+    Uninterruptibles.sleepUninterruptibly(1200, TimeUnit.MILLISECONDS);
+
+    synchronized (callback0) {
+      // not complete yet, but should complete soon
+      assert (callback0.success == null && callback0.failure == null);
+      callback0.wait(2 * 1000);
+      assert (callback0.failure instanceof IOException);
+    }
+
+    synchronized (callback1) {
+      // failed at same time as previous
+      assert (callback0.failure instanceof IOException);
+    }
+  }
+
+  /**
+   * Callback which sets 'success' or 'failure' on completion.
+   * Additionally notifies all waiters on this callback when invoked.
+   */
+  class TestCallback implements RpcResponseCallback, ChunkReceivedCallback {
+
+    byte[] success;
+    Throwable failure;
+
+    @Override
+    public void onSuccess(byte[] response) {
+      synchronized(this) {
+        success = response;
+        this.notifyAll();
+      }
+    }
+
+    @Override
+    public void onFailure(Throwable e) {
+      synchronized(this) {
+        failure = e;
+        this.notifyAll();
+      }
+    }
+
+    @Override
+    public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
+      synchronized(this) {
+        try {
+          success = buffer.nioByteBuffer().array();
+          this.notifyAll();
+        } catch (IOException e) {
+          // weird
+        }
+      }
+    }
+
+    @Override
+    public void onFailure(int chunkIndex, Throwable e) {
+      synchronized(this) {
+        failure = e;
+        this.notifyAll();
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/968ad972/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
----------------------------------------------------------------------
diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
index 416dc1b..35de5e5 100644
--- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
@@ -20,10 +20,11 @@ package org.apache.spark.network;
 import java.io.IOException;
 import java.util.Collections;
 import java.util.HashSet;
-import java.util.NoSuchElementException;
+import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.atomic.AtomicInteger;
 
+import com.google.common.collect.Maps;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
@@ -36,9 +37,9 @@ import org.apache.spark.network.client.TransportClientFactory;
 import org.apache.spark.network.server.NoOpRpcHandler;
 import org.apache.spark.network.server.RpcHandler;
 import org.apache.spark.network.server.TransportServer;
-import org.apache.spark.network.util.ConfigProvider;
-import org.apache.spark.network.util.JavaUtils;
 import org.apache.spark.network.util.SystemPropertyConfigProvider;
+import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.util.MapConfigProvider;
 import org.apache.spark.network.util.TransportConf;
 
 public class TransportClientFactorySuite {
@@ -70,16 +71,10 @@ public class TransportClientFactorySuite {
    */
   private void testClientReuse(final int maxConnections, boolean concurrent)
     throws IOException, InterruptedException {
-    TransportConf conf = new TransportConf(new ConfigProvider() {
-      @Override
-      public String get(String name) {
-        if (name.equals("spark.shuffle.io.numConnectionsPerPeer")) {
-          return Integer.toString(maxConnections);
-        } else {
-          throw new NoSuchElementException();
-        }
-      }
-    });
+
+    Map<String, String> configMap = Maps.newHashMap();
+    configMap.put("spark.shuffle.io.numConnectionsPerPeer", Integer.toString(maxConnections));
+    TransportConf conf = new TransportConf(new MapConfigProvider(configMap));
 
     RpcHandler rpcHandler = new NoOpRpcHandler();
     TransportContext context = new TransportContext(conf, rpcHandler);


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message