spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From van...@apache.org
Subject spark git commit: [SPARK-10997][CORE] Add "client mode" to netty rpc env.
Date Mon, 02 Nov 2015 18:26:43 GMT
Repository: spark
Updated Branches:
  refs/heads/master a930e624e -> 71d1c907d


[SPARK-10997][CORE] Add "client mode" to netty rpc env.

"Client mode" means the RPC env will not listen for incoming connections.
This allows certain processes in the Spark stack (such as Executors or
the YARN client-mode AM) to act as pure clients when using the netty-based
RPC backend, reducing the number of sockets needed by the app and also the
number of open ports.

Client connections are also preferred when endpoints that actually have
a listening socket are involved; so, for example, if a Worker connects
to a Master and the Master needs to send a message to a Worker endpoint,
that client connection will be used, even though the Worker is also
listening for incoming connections.

With this change, the workaround for SPARK-10987 isn't necessary anymore, and
is removed. The AM connects to the driver in "client mode", and that connection
is used for all driver <-> AM communication, and so the AM is properly notified
when the connection goes down.

Author: Marcelo Vanzin <vanzin@cloudera.com>

Closes #9210 from vanzin/SPARK-10997.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/71d1c907
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/71d1c907
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/71d1c907

Branch: refs/heads/master
Commit: 71d1c907dec446db566b19f912159fd8f46deb7d
Parents: a930e62
Author: Marcelo Vanzin <vanzin@cloudera.com>
Authored: Mon Nov 2 10:26:36 2015 -0800
Committer: Marcelo Vanzin <vanzin@cloudera.com>
Committed: Mon Nov 2 10:26:36 2015 -0800

----------------------------------------------------------------------
 .../main/scala/org/apache/spark/SparkEnv.scala  |   7 +-
 .../executor/CoarseGrainedExecutorBackend.scala |  20 +-
 .../scala/org/apache/spark/rpc/RpcEnv.scala     |   8 +-
 .../org/apache/spark/rpc/netty/Dispatcher.scala |   2 +-
 .../apache/spark/rpc/netty/NettyRpcEnv.scala    | 245 ++++++++++---------
 .../org/apache/spark/rpc/netty/Outbox.scala     |  24 +-
 .../spark/rpc/netty/RpcEndpointAddress.scala    |  24 +-
 .../cluster/CoarseGrainedClusterMessage.scala   |  10 +-
 .../cluster/CoarseGrainedSchedulerBackend.scala |  18 +-
 .../cluster/YarnSchedulerBackend.scala          |   2 -
 .../org/apache/spark/rpc/RpcEnvSuite.scala      |  50 ++--
 .../apache/spark/rpc/akka/AkkaRpcEnvSuite.scala |  11 +-
 .../spark/rpc/netty/NettyRpcAddressSuite.scala  |   7 +-
 .../spark/rpc/netty/NettyRpcEnvSuite.scala      |   9 +-
 .../spark/rpc/netty/NettyRpcHandlerSuite.scala  |   8 +-
 network/yarn/pom.xml                            |   5 +
 .../spark/deploy/yarn/ApplicationMaster.scala   |   6 +-
 17 files changed, 266 insertions(+), 190 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/71d1c907/core/src/main/scala/org/apache/spark/SparkEnv.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 398e093..23ae936 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -252,7 +252,8 @@ object SparkEnv extends Logging {
 
     // Create the ActorSystem for Akka and get the port it binds to.
     val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName
-    val rpcEnv = RpcEnv.create(actorSystemName, hostname, port, conf, securityManager)
+    val rpcEnv = RpcEnv.create(actorSystemName, hostname, port, conf, securityManager,
+      clientMode = !isDriver)
     val actorSystem: ActorSystem =
       if (rpcEnv.isInstanceOf[AkkaRpcEnv]) {
         rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem
@@ -262,9 +263,11 @@ object SparkEnv extends Logging {
       }
 
     // Figure out which port Akka actually bound to in case the original port is 0 or occupied.
+    // In the non-driver case, the RPC env's address may be null since it may not be listening
+    // for incoming connections.
     if (isDriver) {
       conf.set("spark.driver.port", rpcEnv.address.port.toString)
-    } else {
+    } else if (rpcEnv.address != null) {
       conf.set("spark.executor.port", rpcEnv.address.port.toString)
     }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/71d1c907/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index a9c6a05..c2ebf30 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -45,8 +45,6 @@ private[spark] class CoarseGrainedExecutorBackend(
     env: SparkEnv)
   extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging {
 
-  Utils.checkHostPort(hostPort, "Expected hostport")
-
   var executor: Executor = null
   @volatile var driver: Option[RpcEndpointRef] = None
 
@@ -80,9 +78,8 @@ private[spark] class CoarseGrainedExecutorBackend(
   }
 
   override def receive: PartialFunction[Any, Unit] = {
-    case RegisteredExecutor =>
+    case RegisteredExecutor(hostname) =>
       logInfo("Successfully registered with driver")
-      val (hostname, _) = Utils.parseHostPort(hostPort)
       executor = new Executor(executorId, hostname, env, userClassPath, isLocal = false)
 
     case RegisterExecutorFailed(message) =>
@@ -163,7 +160,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
         hostname,
         port,
         executorConf,
-        new SecurityManager(executorConf))
+        new SecurityManager(executorConf),
+        clientMode = true)
       val driver = fetcher.setupEndpointRefByURI(driverUrl)
       val props = driver.askWithRetry[Seq[(String, String)]](RetrieveSparkProps) ++
         Seq[(String, String)](("spark.app.id", appId))
@@ -188,12 +186,12 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
       val env = SparkEnv.createExecutorEnv(
         driverConf, executorId, hostname, port, cores, isLocal = false)
 
-      // SparkEnv sets spark.driver.port so it shouldn't be 0 anymore.
-      val boundPort = env.conf.getInt("spark.executor.port", 0)
-      assert(boundPort != 0)
-
-      // Start the CoarseGrainedExecutorBackend endpoint.
-      val sparkHostPort = hostname + ":" + boundPort
+      // SparkEnv will set spark.executor.port if the rpc env is listening for incoming
+      // connections (e.g., if it's using akka). Otherwise, the executor is running in
+      // client mode only, and does not accept incoming connections.
+      val sparkHostPort = env.conf.getOption("spark.executor.port").map { port =>
+          hostname + ":" + port
+        }.orNull
       env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend(
         env.rpcEnv, driverUrl, executorId, sparkHostPort, cores, userClassPath, env))
       workerUrl.foreach { url =>

http://git-wip-us.apache.org/repos/asf/spark/blob/71d1c907/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
index 2c4a8b9..a560fd1 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -43,9 +43,10 @@ private[spark] object RpcEnv {
       host: String,
       port: Int,
       conf: SparkConf,
-      securityManager: SecurityManager): RpcEnv = {
+      securityManager: SecurityManager,
+      clientMode: Boolean = false): RpcEnv = {
     // Using Reflection to create the RpcEnv to avoid to depend on Akka directly
-    val config = RpcEnvConfig(conf, name, host, port, securityManager)
+    val config = RpcEnvConfig(conf, name, host, port, securityManager, clientMode)
     getRpcEnvFactory(conf).create(config)
   }
 }
@@ -139,4 +140,5 @@ private[spark] case class RpcEnvConfig(
     name: String,
     host: String,
     port: Int,
-    securityManager: SecurityManager)
+    securityManager: SecurityManager,
+    clientMode: Boolean)

http://git-wip-us.apache.org/repos/asf/spark/blob/71d1c907/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
index 7bf44a6..eb25d6c 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
@@ -55,7 +55,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
   private var stopped = false
 
   def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = {
-    val addr = new RpcEndpointAddress(nettyEnv.address.host, nettyEnv.address.port, name)
+    val addr = RpcEndpointAddress(nettyEnv.address, name)
     val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv)
     synchronized {
       if (stopped) {

http://git-wip-us.apache.org/repos/asf/spark/blob/71d1c907/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index 284284e..0909381 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -17,10 +17,12 @@
 package org.apache.spark.rpc.netty
 
 import java.io._
+import java.lang.{Boolean => JBoolean}
 import java.net.{InetSocketAddress, URI}
 import java.nio.ByteBuffer
 import java.util.concurrent._
 import java.util.concurrent.atomic.AtomicBoolean
+import javax.annotation.Nullable;
 import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.mutable
@@ -29,6 +31,7 @@ import scala.reflect.ClassTag
 import scala.util.{DynamicVariable, Failure, Success}
 import scala.util.control.NonFatal
 
+import com.google.common.base.Preconditions
 import org.apache.spark.{Logging, SecurityManager, SparkConf}
 import org.apache.spark.network.TransportContext
 import org.apache.spark.network.client._
@@ -45,15 +48,14 @@ private[netty] class NettyRpcEnv(
     host: String,
     securityManager: SecurityManager) extends RpcEnv(conf) with Logging {
 
-  // Override numConnectionsPerPeer to 1 for RPC.
   private val transportConf = SparkTransportConf.fromSparkConf(
     conf.clone.set("spark.shuffle.io.numConnectionsPerPeer", "1"),
     conf.getInt("spark.rpc.io.threads", 0))
 
   private val dispatcher: Dispatcher = new Dispatcher(this)
 
-  private val transportContext =
-    new TransportContext(transportConf, new NettyRpcHandler(dispatcher, this))
+  private val transportContext = new TransportContext(transportConf,
+    new NettyRpcHandler(dispatcher, this))
 
   private val clientFactory = {
     val bootstraps: java.util.List[TransportClientBootstrap] =
@@ -95,7 +97,7 @@ private[netty] class NettyRpcEnv(
     }
   }
 
-  def start(port: Int): Unit = {
+  def startServer(port: Int): Unit = {
     val bootstraps: java.util.List[TransportServerBootstrap] =
       if (securityManager.isAuthenticationEnabled()) {
         java.util.Arrays.asList(new SaslServerBootstrap(transportConf, securityManager))
@@ -107,9 +109,9 @@ private[netty] class NettyRpcEnv(
       RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher))
   }
 
+  @Nullable
   override lazy val address: RpcAddress = {
-    require(server != null, "NettyRpcEnv has not yet started")
-    RpcAddress(host, server.getPort)
+    if (server != null) RpcAddress(host, server.getPort()) else null
   }
 
   override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = {
@@ -120,7 +122,7 @@ private[netty] class NettyRpcEnv(
     val addr = RpcEndpointAddress(uri)
     val endpointRef = new NettyRpcEndpointRef(conf, addr, this)
     val verifier = new NettyRpcEndpointRef(
-      conf, RpcEndpointAddress(addr.host, addr.port, RpcEndpointVerifier.NAME), this)
+      conf, RpcEndpointAddress(addr.rpcAddress, RpcEndpointVerifier.NAME), this)
     verifier.ask[Boolean](RpcEndpointVerifier.CheckExistence(endpointRef.name)).flatMap { find =>
       if (find) {
         Future.successful(endpointRef)
@@ -135,28 +137,34 @@ private[netty] class NettyRpcEnv(
     dispatcher.stop(endpointRef)
   }
 
-  private def postToOutbox(address: RpcAddress, message: OutboxMessage): Unit = {
-    val targetOutbox = {
-      val outbox = outboxes.get(address)
-      if (outbox == null) {
-        val newOutbox = new Outbox(this, address)
-        val oldOutbox = outboxes.putIfAbsent(address, newOutbox)
-        if (oldOutbox == null) {
-          newOutbox
+  private def postToOutbox(receiver: NettyRpcEndpointRef, message: OutboxMessage): Unit = {
+    if (receiver.client != null) {
+      receiver.client.sendRpc(message.content, message.createCallback(receiver.client));
+    } else {
+      require(receiver.address != null,
+        "Cannot send message to client endpoint with no listen address.")
+      val targetOutbox = {
+        val outbox = outboxes.get(receiver.address)
+        if (outbox == null) {
+          val newOutbox = new Outbox(this, receiver.address)
+          val oldOutbox = outboxes.putIfAbsent(receiver.address, newOutbox)
+          if (oldOutbox == null) {
+            newOutbox
+          } else {
+            oldOutbox
+          }
         } else {
-          oldOutbox
+          outbox
         }
+      }
+      if (stopped.get) {
+        // It's possible that we put `targetOutbox` after stopping. So we need to clean it.
+        outboxes.remove(receiver.address)
+        targetOutbox.stop()
       } else {
-        outbox
+        targetOutbox.send(message)
       }
     }
-    if (stopped.get) {
-      // It's possible that we put `targetOutbox` after stopping. So we need to clean it.
-      outboxes.remove(address)
-      targetOutbox.stop()
-    } else {
-      targetOutbox.send(message)
-    }
   }
 
   private[netty] def send(message: RequestMessage): Unit = {
@@ -174,17 +182,14 @@ private[netty] class NettyRpcEnv(
       }(ThreadUtils.sameThread)
     } else {
       // Message to a remote RPC endpoint.
-      postToOutbox(remoteAddr, OutboxMessage(serialize(message), new RpcResponseCallback {
-
-        override def onFailure(e: Throwable): Unit = {
+      postToOutbox(message.receiver, OutboxMessage(serialize(message),
+        (e) => {
           logWarning(s"Exception when sending $message", e)
-        }
-
-        override def onSuccess(response: Array[Byte]): Unit = {
-          val ack = deserialize[Ack](response)
+        },
+        (client, response) => {
+          val ack = deserialize[Ack](client, response)
           logDebug(s"Receive ack from ${ack.sender}")
-        }
-      }))
+        }))
     }
   }
 
@@ -214,16 +219,14 @@ private[netty] class NettyRpcEnv(
           }
       }(ThreadUtils.sameThread)
     } else {
-      postToOutbox(remoteAddr, OutboxMessage(serialize(message), new RpcResponseCallback {
-
-        override def onFailure(e: Throwable): Unit = {
+      postToOutbox(message.receiver, OutboxMessage(serialize(message),
+        (e) => {
           if (!promise.tryFailure(e)) {
             logWarning("Ignore Exception", e)
           }
-        }
-
-        override def onSuccess(response: Array[Byte]): Unit = {
-          val reply = deserialize[AskResponse](response)
+        },
+        (client, response) => {
+          val reply = deserialize[AskResponse](client, response)
           if (reply.reply.isInstanceOf[RpcFailure]) {
             if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) {
               logWarning(s"Ignore failure: ${reply.reply}")
@@ -231,8 +234,7 @@ private[netty] class NettyRpcEnv(
           } else if (!promise.trySuccess(reply.reply)) {
             logWarning(s"Ignore message: ${reply}")
           }
-        }
-      }))
+        }))
     }
     promise.future
   }
@@ -243,9 +245,11 @@ private[netty] class NettyRpcEnv(
       buffer.array(), buffer.arrayOffset + buffer.position, buffer.arrayOffset + buffer.limit)
   }
 
-  private[netty] def deserialize[T: ClassTag](bytes: Array[Byte]): T = {
-    deserialize { () =>
-      javaSerializerInstance.deserialize[T](ByteBuffer.wrap(bytes))
+  private[netty] def deserialize[T: ClassTag](client: TransportClient, bytes: Array[Byte]): T = {
+    NettyRpcEnv.currentClient.withValue(client) {
+      deserialize { () =>
+        javaSerializerInstance.deserialize[T](ByteBuffer.wrap(bytes))
+      }
     }
   }
 
@@ -254,7 +258,7 @@ private[netty] class NettyRpcEnv(
   }
 
   override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String =
-    new RpcEndpointAddress(address.host, address.port, endpointName).toString
+    new RpcEndpointAddress(address, endpointName).toString
 
   override def shutdown(): Unit = {
     cleanup()
@@ -297,6 +301,7 @@ private[netty] class NettyRpcEnv(
       deserializationAction()
     }
   }
+
 }
 
 private[netty] object NettyRpcEnv extends Logging {
@@ -312,6 +317,13 @@ private[netty] object NettyRpcEnv extends Logging {
    * }}}
    */
   private[netty] val currentEnv = new DynamicVariable[NettyRpcEnv](null)
+
+  /**
+   * Similar to `currentEnv`, this variable references the client instance associated with an
+   * RPC, in case it's needed to find out the remote address during deserialization.
+   */
+  private[netty] val currentClient = new DynamicVariable[TransportClient](null)
+
 }
 
 private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {
@@ -324,47 +336,68 @@ private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {
       new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance]
     val nettyEnv =
       new NettyRpcEnv(sparkConf, javaSerializerInstance, config.host, config.securityManager)
-    val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort =>
-      nettyEnv.start(actualPort)
-      (nettyEnv, actualPort)
-    }
-    try {
-      Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, "NettyRpcEnv")._1
-    } catch {
-      case NonFatal(e) =>
-        nettyEnv.shutdown()
-        throw e
+    if (!config.clientMode) {
+      val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort =>
+        nettyEnv.startServer(actualPort)
+        (nettyEnv, actualPort)
+      }
+      try {
+        Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, "NettyRpcEnv")._1
+      } catch {
+        case NonFatal(e) =>
+          nettyEnv.shutdown()
+          throw e
+      }
     }
+    nettyEnv
   }
 }
 
-private[netty] class NettyRpcEndpointRef(@transient private val conf: SparkConf)
+/**
+ * The NettyRpcEnv version of RpcEndpointRef.
+ *
+ * This class behaves differently depending on where it's created. On the node that "owns" the
+ * RpcEndpoint, it's a simple wrapper around the RpcEndpointAddress instance.
+ *
+ * On other machines that receive a serialized version of the reference, the behavior changes. The
+ * instance will keep track of the TransportClient that sent the reference, so that messages
+ * to the endpoint are sent over the client connection, instead of needing a new connection to
+ * be opened.
+ *
+ * The RpcAddress of this ref can be null; what that means is that the ref can only be used through
+ * a client connection, since the process hosting the endpoint is not listening for incoming
+ * connections. These refs should not be shared with 3rd parties, since they will not be able to
+ * send messages to the endpoint.
+ *
+ * @param conf Spark configuration.
+ * @param endpointAddress The address where the endpoint is listening.
+ * @param nettyEnv The RpcEnv associated with this ref.
+ * @param local Whether the referenced endpoint lives in the same process.
+ */
+private[netty] class NettyRpcEndpointRef(
+    @transient private val conf: SparkConf,
+    endpointAddress: RpcEndpointAddress,
+    @transient @volatile private var nettyEnv: NettyRpcEnv)
   extends RpcEndpointRef(conf) with Serializable with Logging {
 
-  @transient @volatile private var nettyEnv: NettyRpcEnv = _
+  @transient @volatile var client: TransportClient = _
 
-  @transient @volatile private var _address: RpcEndpointAddress = _
+  private val _address = if (endpointAddress.rpcAddress != null) endpointAddress else null
+  private val _name = endpointAddress.name
 
-  def this(conf: SparkConf, _address: RpcEndpointAddress, nettyEnv: NettyRpcEnv) {
-    this(conf)
-    this._address = _address
-    this.nettyEnv = nettyEnv
-  }
-
-  override def address: RpcAddress = _address.toRpcAddress
+  override def address: RpcAddress = if (_address != null) _address.rpcAddress else null
 
   private def readObject(in: ObjectInputStream): Unit = {
     in.defaultReadObject()
-    _address = in.readObject().asInstanceOf[RpcEndpointAddress]
     nettyEnv = NettyRpcEnv.currentEnv.value
+    client = NettyRpcEnv.currentClient.value
   }
 
   private def writeObject(out: ObjectOutputStream): Unit = {
     out.defaultWriteObject()
-    out.writeObject(_address)
   }
 
-  override def name: String = _address.name
+  override def name: String = _name
 
   override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
     val promise = Promise[Any]()
@@ -429,41 +462,43 @@ private[netty] case class Ack(sender: NettyRpcEndpointRef) extends ResponseMessa
 private[netty] case class RpcFailure(e: Throwable)
 
 /**
- * Maintain the mapping relations between client addresses and [[RpcEnv]] addresses, broadcast
- * network events and forward messages to [[Dispatcher]].
+ * Dispatches incoming RPCs to registered endpoints.
+ *
+ * The handler keeps track of all client instances that communicate with it, so that the RpcEnv
+ * knows which `TransportClient` instance to use when sending RPCs to a client endpoint (i.e.,
+ * one that is not listening for incoming connections, but rather needs to be contacted via the
+ * client socket).
+ *
+ * Events are sent on a per-connection basis, so if a client opens multiple connections to the
+ * RpcEnv, multiple connection / disconnection events will be created for that client (albeit
+ * with different `RpcAddress` information).
  */
 private[netty] class NettyRpcHandler(
     dispatcher: Dispatcher, nettyEnv: NettyRpcEnv) extends RpcHandler with Logging {
 
-  private type ClientAddress = RpcAddress
-  private type RemoteEnvAddress = RpcAddress
-
-  // Store all client addresses and their NettyRpcEnv addresses.
-  // TODO: Is this even necessary?
-  @GuardedBy("this")
-  private val remoteAddresses = new mutable.HashMap[ClientAddress, RemoteEnvAddress]()
+  // TODO: Can we add connection callback (channel registered) to the underlying framework?
+  // A variable to track whether we should dispatch the RemoteProcessConnected message.
+  private val clients = new ConcurrentHashMap[TransportClient, JBoolean]()
 
   override def receive(
-      client: TransportClient, message: Array[Byte], callback: RpcResponseCallback): Unit = {
-    val requestMessage = nettyEnv.deserialize[RequestMessage](message)
-    val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
+      client: TransportClient,
+      message: Array[Byte],
+      callback: RpcResponseCallback): Unit = {
+    val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
     assert(addr != null)
-    val remoteEnvAddress = requestMessage.senderAddress
     val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
-
-    // TODO: Can we add connection callback (channel registered) to the underlying framework?
-    // A variable to track whether we should dispatch the RemoteProcessConnected message.
-    var dispatchRemoteProcessConnected = false
-    synchronized {
-      if (remoteAddresses.put(clientAddr, remoteEnvAddress).isEmpty) {
-        // clientAddr connects at the first time, fire "RemoteProcessConnected"
-        dispatchRemoteProcessConnected = true
-      }
+    if (clients.putIfAbsent(client, JBoolean.TRUE) == null) {
+      dispatcher.postToAll(RemoteProcessConnected(clientAddr))
     }
-    if (dispatchRemoteProcessConnected) {
-      dispatcher.postToAll(RemoteProcessConnected(remoteEnvAddress))
-    }
-    dispatcher.postRemoteMessage(requestMessage, callback)
+    val requestMessage = nettyEnv.deserialize[RequestMessage](client, message)
+    val messageToDispatch = if (requestMessage.senderAddress == null) {
+        // Create a new message with the socket address of the client as the sender.
+        RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content,
+          requestMessage.needReply)
+      } else {
+        requestMessage
+      }
+    dispatcher.postRemoteMessage(messageToDispatch, callback)
   }
 
   override def getStreamManager: StreamManager = new OneForOneStreamManager
@@ -472,15 +507,7 @@ private[netty] class NettyRpcHandler(
     val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
     if (addr != null) {
       val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
-      val broadcastMessage =
-        synchronized {
-          remoteAddresses.get(clientAddr).map(RemoteProcessConnectionError(cause, _))
-        }
-      if (broadcastMessage.isEmpty) {
-        logError(cause.getMessage, cause)
-      } else {
-        dispatcher.postToAll(broadcastMessage.get)
-      }
+      dispatcher.postToAll(RemoteProcessConnectionError(cause, clientAddr))
     } else {
       // If the channel is closed before connecting, its remoteAddress will be null.
       // See java.net.Socket.getRemoteSocketAddress
@@ -493,15 +520,9 @@ private[netty] class NettyRpcHandler(
     val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
     if (addr != null) {
       val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
+      clients.remove(client)
       nettyEnv.removeOutbox(clientAddr)
-      val messageOpt: Option[RemoteProcessDisconnected] =
-      synchronized {
-        remoteAddresses.get(clientAddr).flatMap { remoteEnvAddress =>
-          remoteAddresses -= clientAddr
-          Some(RemoteProcessDisconnected(remoteEnvAddress))
-        }
-      }
-      messageOpt.foreach(dispatcher.postToAll)
+      dispatcher.postToAll(RemoteProcessDisconnected(clientAddr))
     } else {
       // If the channel is closed before connecting, its remoteAddress will be null. In this case,
       // we can ignore it since we don't fire "Associated".

http://git-wip-us.apache.org/repos/asf/spark/blob/71d1c907/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala
index 7d9d593..2f6817f 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala
@@ -26,7 +26,21 @@ import org.apache.spark.SparkException
 import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
 import org.apache.spark.rpc.RpcAddress
 
-private[netty] case class OutboxMessage(content: Array[Byte], callback: RpcResponseCallback)
+private[netty] case class OutboxMessage(content: Array[Byte],
+  _onFailure: (Throwable) => Unit,
+  _onSuccess: (TransportClient, Array[Byte]) => Unit) {
+
+  def createCallback(client: TransportClient): RpcResponseCallback = new RpcResponseCallback() {
+    override def onFailure(e: Throwable): Unit = {
+      _onFailure(e)
+    }
+
+    override def onSuccess(response: Array[Byte]): Unit = {
+      _onSuccess(client, response)
+    }
+  }
+
+}
 
 private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
 
@@ -68,7 +82,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
       }
     }
     if (dropped) {
-      message.callback.onFailure(new SparkException("Message is dropped because Outbox is stopped"))
+      message._onFailure(new SparkException("Message is dropped because Outbox is stopped"))
     } else {
       drainOutbox()
     }
@@ -108,7 +122,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
       try {
         val _client = synchronized { client }
         if (_client != null) {
-          _client.sendRpc(message.content, message.callback)
+          _client.sendRpc(message.content, message.createCallback(_client))
         } else {
           assert(stopped == true)
         }
@@ -181,7 +195,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
     // update messages and it's safe to just drain the queue.
     var message = messages.poll()
     while (message != null) {
-      message.callback.onFailure(e)
+      message._onFailure(e)
       message = messages.poll()
     }
     assert(messages.isEmpty)
@@ -215,7 +229,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
     // update messages and it's safe to just drain the queue.
     var message = messages.poll()
     while (message != null) {
-      message.callback.onFailure(new SparkException("Message is dropped because Outbox is stopped"))
+      message._onFailure(new SparkException("Message is dropped because Outbox is stopped"))
       message = messages.poll()
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/71d1c907/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala
index 87b6236..d2e94f9 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala
@@ -23,15 +23,25 @@ import org.apache.spark.rpc.RpcAddress
 /**
  * An address identifier for an RPC endpoint.
  *
- * @param host host name of the remote process.
- * @param port the port the remote RPC environment binds to.
- * @param name name of the remote endpoint.
+ * The `rpcAddress` may be null, in which case the endpoint is registered via a client-only
+ * connection and can only be reached via the client that sent the endpoint reference.
+ *
+ * @param rpcAddress The socket address of the endpoint.
+ * @param name Name of the endpoint.
  */
-private[netty] case class RpcEndpointAddress(host: String, port: Int, name: String) {
+private[netty] case class RpcEndpointAddress(val rpcAddress: RpcAddress, val name: String) {
+
+  require(name != null, "RpcEndpoint name must be provided.")
 
-  def toRpcAddress: RpcAddress = RpcAddress(host, port)
+  def this(host: String, port: Int, name: String) = {
+    this(RpcAddress(host, port), name)
+  }
 
-  override val toString = s"spark://$name@$host:$port"
+  override val toString = if (rpcAddress != null) {
+      s"spark://$name@${rpcAddress.host}:${rpcAddress.port}"
+    } else {
+      s"spark-client://$name"
+    }
 }
 
 private[netty] object RpcEndpointAddress {
@@ -51,7 +61,7 @@ private[netty] object RpcEndpointAddress {
           uri.getQuery != null) {
         throw new SparkException("Invalid Spark URL: " + sparkUrl)
       }
-      RpcEndpointAddress(host, port, name)
+      new RpcEndpointAddress(host, port, name)
     } catch {
       case e: java.net.URISyntaxException =>
         throw new SparkException("Invalid Spark URL: " + sparkUrl, e)

http://git-wip-us.apache.org/repos/asf/spark/blob/71d1c907/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
index 8103efa..f3d0d85 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
@@ -38,7 +38,7 @@ private[spark] object CoarseGrainedClusterMessages {
 
   sealed trait RegisterExecutorResponse
 
-  case object RegisteredExecutor extends CoarseGrainedClusterMessage
+  case class RegisteredExecutor(hostname: String) extends CoarseGrainedClusterMessage
     with RegisterExecutorResponse
 
   case class RegisterExecutorFailed(message: String) extends CoarseGrainedClusterMessage
@@ -51,9 +51,7 @@ private[spark] object CoarseGrainedClusterMessages {
       hostPort: String,
       cores: Int,
       logUrls: Map[String, String])
-    extends CoarseGrainedClusterMessage {
-    Utils.checkHostPort(hostPort, "Expected host port")
-  }
+    extends CoarseGrainedClusterMessage
 
   case class StatusUpdate(executorId: String, taskId: Long, state: TaskState,
     data: SerializableBuffer) extends CoarseGrainedClusterMessage
@@ -107,8 +105,4 @@ private[spark] object CoarseGrainedClusterMessages {
   // Used internally by executors to shut themselves down.
   case object Shutdown extends CoarseGrainedClusterMessage
 
-  // SPARK-10987: workaround for netty RPC issue; forces a connection from the driver back
-  // to the AM.
-  case object DriverHello extends CoarseGrainedClusterMessage
-
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/71d1c907/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 55a564b..439a119 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -131,16 +131,22 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
     override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
 
       case RegisterExecutor(executorId, executorRef, hostPort, cores, logUrls) =>
-        Utils.checkHostPort(hostPort, "Host port expected " + hostPort)
         if (executorDataMap.contains(executorId)) {
           context.reply(RegisterExecutorFailed("Duplicate executor ID: " + executorId))
         } else {
-          logInfo("Registered executor: " + executorRef + " with ID " + executorId)
-          addressToExecutorId(executorRef.address) = executorId
+          // If the executor's rpc env is not listening for incoming connections, the address of
+          // `executorRef` will be null, and the client connection should be used to contact it.
+          val executorAddress = if (executorRef.address != null) {
+              executorRef.address
+            } else {
+              context.senderAddress
+            }
+          logInfo(s"Registered executor $executorRef ($executorAddress) with ID $executorId")
+          addressToExecutorId(executorAddress) = executorId
           totalCoreCount.addAndGet(cores)
           totalRegisteredExecutors.addAndGet(1)
-          val (host, _) = Utils.parseHostPort(hostPort)
-          val data = new ExecutorData(executorRef, executorRef.address, host, cores, cores, logUrls)
+          val data = new ExecutorData(executorRef, executorRef.address, executorAddress.host,
+            cores, cores, logUrls)
           // This must be synchronized because variables mutated
           // in this block are read when requesting executors
           CoarseGrainedSchedulerBackend.this.synchronized {
@@ -151,7 +157,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
             }
           }
           // Note: some tests expect the reply to come after we put the executor in the map
-          context.reply(RegisteredExecutor)
+          context.reply(RegisteredExecutor(executorAddress.host))
           listenerBus.post(
             SparkListenerExecutorAdded(System.currentTimeMillis(), executorId, data))
           makeOffers()

http://git-wip-us.apache.org/repos/asf/spark/blob/71d1c907/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
index e483688..cb24072 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
@@ -170,8 +170,6 @@ private[spark] abstract class YarnSchedulerBackend(
       case RegisterClusterManager(am) =>
         logInfo(s"ApplicationMaster registered as $am")
         amEndpoint = Option(am)
-        // See SPARK-10987.
-        am.send(DriverHello)
 
       case AddWebUIFilter(filterName, filterParams, proxyBase) =>
         addWebUIFilter(filterName, filterParams, proxyBase)

http://git-wip-us.apache.org/repos/asf/spark/blob/71d1c907/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
index 3bead63..834e474 100644
--- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
@@ -48,7 +48,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
     }
   }
 
-  def createRpcEnv(conf: SparkConf, name: String, port: Int): RpcEnv
+  def createRpcEnv(conf: SparkConf, name: String, port: Int, clientMode: Boolean = false): RpcEnv
 
   test("send a message locally") {
     @volatile var message: String = null
@@ -76,7 +76,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
       }
     })
 
-    val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345)
+    val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true)
     // Use anotherEnv to find out the RpcEndpointRef
     val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "send-remotely")
     try {
@@ -130,7 +130,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
       }
     })
 
-    val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345)
+    val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true)
     // Use anotherEnv to find out the RpcEndpointRef
     val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-remotely")
     try {
@@ -158,7 +158,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
     val shortProp = "spark.rpc.short.timeout"
     conf.set("spark.rpc.retry.wait", "0")
     conf.set("spark.rpc.numRetries", "1")
-    val anotherEnv = createRpcEnv(conf, "remote", 13345)
+    val anotherEnv = createRpcEnv(conf, "remote", 13345, clientMode = true)
     // Use anotherEnv to find out the RpcEndpointRef
     val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-timeout")
     try {
@@ -417,7 +417,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
       }
     })
 
-    val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345)
+    val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true)
     // Use anotherEnv to find out the RpcEndpointRef
     val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "sendWithReply-remotely")
     try {
@@ -457,7 +457,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
       }
     })
 
-    val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345)
+    val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true)
     // Use anotherEnv to find out the RpcEndpointRef
     val rpcEndpointRef = anotherEnv.setupEndpointRef(
       "local", env.address, "sendWithReply-remotely-error")
@@ -497,26 +497,40 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
 
     })
 
-    val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345)
+    val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true)
     // Use anotherEnv to find out the RpcEndpointRef
     val rpcEndpointRef = anotherEnv.setupEndpointRef(
       "local", env.address, "network-events")
     val remoteAddress = anotherEnv.address
     rpcEndpointRef.send("hello")
     eventually(timeout(5 seconds), interval(5 millis)) {
-      assert(events === List(("onConnected", remoteAddress)))
+      // anotherEnv is connected in client mode, so the remote address may be unknown depending on
+      // the implementation. Account for that when doing checks.
+      if (remoteAddress != null) {
+        assert(events === List(("onConnected", remoteAddress)))
+      } else {
+        assert(events.size === 1)
+        assert(events(0)._1 === "onConnected")
+      }
     }
 
     anotherEnv.shutdown()
     anotherEnv.awaitTermination()
     eventually(timeout(5 seconds), interval(5 millis)) {
-      assert(events === List(
-        ("onConnected", remoteAddress),
-        ("onNetworkError", remoteAddress),
-        ("onDisconnected", remoteAddress)) ||
-        events === List(
-        ("onConnected", remoteAddress),
-        ("onDisconnected", remoteAddress)))
+      // Account for anotherEnv not having an address due to running in client mode.
+      if (remoteAddress != null) {
+        assert(events === List(
+          ("onConnected", remoteAddress),
+          ("onNetworkError", remoteAddress),
+          ("onDisconnected", remoteAddress)) ||
+          events === List(
+          ("onConnected", remoteAddress),
+          ("onDisconnected", remoteAddress)))
+      } else {
+        val eventNames = events.map(_._1)
+        assert(eventNames === List("onConnected", "onNetworkError", "onDisconnected") ||
+          eventNames === List("onConnected", "onDisconnected"))
+      }
     }
   }
 
@@ -529,7 +543,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
       }
     })
 
-    val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345)
+    val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true)
     // Use anotherEnv to find out the RpcEndpointRef
     val rpcEndpointRef = anotherEnv.setupEndpointRef(
       "local", env.address, "sendWithReply-unserializable-error")
@@ -558,7 +572,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
     conf.set("spark.authenticate.secret", "good")
 
     val localEnv = createRpcEnv(conf, "authentication-local", 13345)
-    val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345)
+    val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345, clientMode = true)
 
     try {
       @volatile var message: String = null
@@ -589,7 +603,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
     conf.set("spark.authenticate.secret", "good")
 
     val localEnv = createRpcEnv(conf, "authentication-local", 13345)
-    val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345)
+    val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345, clientMode = true)
 
     try {
       localEnv.setupEndpoint("ask-authentication", new RpcEndpoint {

http://git-wip-us.apache.org/repos/asf/spark/blob/71d1c907/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala
index 4aa75c9..6478ab5 100644
--- a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala
@@ -22,9 +22,12 @@ import org.apache.spark.{SSLSampleConfigs, SecurityManager, SparkConf}
 
 class AkkaRpcEnvSuite extends RpcEnvSuite {
 
-  override def createRpcEnv(conf: SparkConf, name: String, port: Int): RpcEnv = {
+  override def createRpcEnv(conf: SparkConf,
+      name: String,
+      port: Int,
+      clientMode: Boolean = false): RpcEnv = {
     new AkkaRpcEnvFactory().create(
-      RpcEnvConfig(conf, name, "localhost", port, new SecurityManager(conf)))
+      RpcEnvConfig(conf, name, "localhost", port, new SecurityManager(conf), clientMode))
   }
 
   test("setupEndpointRef: systemName, address, endpointName") {
@@ -37,7 +40,7 @@ class AkkaRpcEnvSuite extends RpcEnvSuite {
     })
     val conf = new SparkConf()
     val newRpcEnv = new AkkaRpcEnvFactory().create(
-      RpcEnvConfig(conf, "test", "localhost", 12346, new SecurityManager(conf)))
+      RpcEnvConfig(conf, "test", "localhost", 12346, new SecurityManager(conf), false))
     try {
       val newRef = newRpcEnv.setupEndpointRef("local", ref.address, "test_endpoint")
       assert(s"akka.tcp://local@${env.address}/user/test_endpoint" ===
@@ -56,7 +59,7 @@ class AkkaRpcEnvSuite extends RpcEnvSuite {
     val conf = SSLSampleConfigs.sparkSSLConfig()
     val securityManager = new SecurityManager(conf)
     val rpcEnv = new AkkaRpcEnvFactory().create(
-      RpcEnvConfig(conf, "test", "localhost", 12346, securityManager))
+      RpcEnvConfig(conf, "test", "localhost", 12346, securityManager, false))
     try {
       val uri = rpcEnv.uriOf("local", RpcAddress("1.2.3.4", 12345), "test_endpoint")
       assert("akka.ssl.tcp://local@1.2.3.4:12345/user/test_endpoint" === uri)

http://git-wip-us.apache.org/repos/asf/spark/blob/71d1c907/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala
index 973a07a..56743ba 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala
@@ -22,8 +22,13 @@ import org.apache.spark.SparkFunSuite
 class NettyRpcAddressSuite extends SparkFunSuite {
 
   test("toString") {
-    val addr = RpcEndpointAddress("localhost", 12345, "test")
+    val addr = new RpcEndpointAddress("localhost", 12345, "test")
     assert(addr.toString === "spark://test@localhost:12345")
   }
 
+  test("toString for client mode") {
+    val addr = RpcEndpointAddress(null, "test")
+    assert(addr.toString === "spark-client://test")
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/71d1c907/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala
index be19668..ce83087 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala
@@ -22,8 +22,13 @@ import org.apache.spark.rpc._
 
 class NettyRpcEnvSuite extends RpcEnvSuite {
 
-  override def createRpcEnv(conf: SparkConf, name: String, port: Int): RpcEnv = {
-    val config = RpcEnvConfig(conf, "test", "localhost", port, new SecurityManager(conf))
+  override def createRpcEnv(
+      conf: SparkConf,
+      name: String,
+      port: Int,
+      clientMode: Boolean = false): RpcEnv = {
+    val config = RpcEnvConfig(conf, "test", "localhost", port, new SecurityManager(conf),
+      clientMode)
     new NettyRpcEnvFactory().create(config)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/71d1c907/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
index 5430e4c..f9d8e80 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
@@ -30,7 +30,7 @@ import org.apache.spark.rpc._
 class NettyRpcHandlerSuite extends SparkFunSuite {
 
   val env = mock(classOf[NettyRpcEnv])
-  when(env.deserialize(any(classOf[Array[Byte]]))(any())).
+  when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any())).
     thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false))
 
   test("receive") {
@@ -42,7 +42,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite {
     when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000))
     nettyRpcHandler.receive(client, null, null)
 
-    verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 12345)))
+    verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 40000)))
   }
 
   test("connectionTerminated") {
@@ -57,9 +57,9 @@ class NettyRpcHandlerSuite extends SparkFunSuite {
     when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000))
     nettyRpcHandler.connectionTerminated(client)
 
-    verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 12345)))
+    verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 40000)))
     verify(dispatcher, times(1)).postToAll(
-      RemoteProcessDisconnected(RpcAddress("localhost", 12345)))
+      RemoteProcessDisconnected(RpcAddress("localhost", 40000)))
   }
 
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/71d1c907/network/yarn/pom.xml
----------------------------------------------------------------------
diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml
index 541ed9a..e2360ef 100644
--- a/network/yarn/pom.xml
+++ b/network/yarn/pom.xml
@@ -54,6 +54,11 @@
       <groupId>org.apache.hadoop</groupId>
       <artifactId>hadoop-client</artifactId>
     </dependency>
+    <dependency>
+      <groupId>org.slf4j</groupId>
+      <artifactId>slf4j-api</artifactId>
+      <scope>provided</scope>
+    </dependency>
   </dependencies>
 
   <build>

http://git-wip-us.apache.org/repos/asf/spark/blob/71d1c907/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
----------------------------------------------------------------------
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index c6a6d7a..12ae350 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -321,7 +321,8 @@ private[spark] class ApplicationMaster(
 
   private def runExecutorLauncher(securityMgr: SecurityManager): Unit = {
     val port = sparkConf.getInt("spark.yarn.am.port", 0)
-    rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, port, sparkConf, securityMgr)
+    rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, port, sparkConf, securityMgr,
+      clientMode = true)
     val driverRef = waitForSparkDriver()
     addAmIpFilter()
     registerAM(rpcEnv, driverRef, sparkConf.get("spark.driver.appUIAddress", ""), securityMgr)
@@ -574,9 +575,6 @@ private[spark] class ApplicationMaster(
       case x: AddWebUIFilter =>
         logInfo(s"Add WebUI Filter. $x")
         driver.send(x)
-
-      case DriverHello =>
-        // SPARK-10987: no action needed for this message.
     }
 
     override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message