spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From pwend...@apache.org
Subject [2/3] git commit: [SPARK-2938] Support SASL authentication in NettyBlockTransferService
Date Wed, 05 Nov 2014 22:42:34 GMT
[SPARK-2938] Support SASL authentication in NettyBlockTransferService

Also lays the groundwork for supporting it inside the external shuffle service.

Author: Aaron Davidson <aaron@databricks.com>

Closes #3087 from aarondav/sasl and squashes the following commits:

3481718 [Aaron Davidson] Delete rogue println
44f8410 [Aaron Davidson] Delete documentation - muahaha!
eb9f065 [Aaron Davidson] Improve documentation and add end-to-end test at Spark-level
a6b95f1 [Aaron Davidson] Address comments
785bbde [Aaron Davidson] Cleanup
79973cb [Aaron Davidson] Remove unused file
151b3c5 [Aaron Davidson] Add docs, timeout config, better failure handling
f6177d7 [Aaron Davidson] Cleanup SASL state upon connection termination
7b42adb [Aaron Davidson] Add unit tests
8191bcb [Aaron Davidson] [SPARK-2938] Support SASL authentication in NettyBlockTransferService


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/23643403
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/23643403
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/23643403

Branch: refs/heads/branch-1.2
Commit: 236434033fe452e70dbd0236935a49693712e130
Parents: 9cba88c
Author: Aaron Davidson <aaron@databricks.com>
Authored: Tue Nov 4 16:15:38 2014 -0800
Committer: Patrick Wendell <pwendell@gmail.com>
Committed: Wed Nov 5 14:41:03 2014 -0800

----------------------------------------------------------------------
 .../org/apache/spark/SecurityManager.scala      |  23 ++-
 .../main/scala/org/apache/spark/SparkConf.scala |   6 +
 .../scala/org/apache/spark/SparkContext.scala   |   2 +
 .../main/scala/org/apache/spark/SparkEnv.scala  |   3 +-
 .../org/apache/spark/SparkSaslClient.scala      | 147 ----------------
 .../org/apache/spark/SparkSaslServer.scala      | 176 -------------------
 .../org/apache/spark/executor/Executor.scala    |   1 +
 .../netty/NettyBlockTransferService.scala       |  28 ++-
 .../apache/spark/network/nio/Connection.scala   |   5 +-
 .../spark/network/nio/ConnectionManager.scala   |   7 +-
 .../org/apache/spark/storage/BlockManager.scala |  45 +++--
 .../netty/NettyBlockTransferSecuritySuite.scala | 161 +++++++++++++++++
 .../network/nio/ConnectionManagerSuite.scala    |   6 +-
 .../storage/BlockManagerReplicationSuite.scala  |   2 +
 .../spark/storage/BlockManagerSuite.scala       |   4 +-
 docs/security.md                                |   1 -
 .../apache/spark/network/TransportContext.java  |  15 +-
 .../spark/network/client/TransportClient.java   |  11 +-
 .../client/TransportClientBootstrap.java        |  32 ++++
 .../network/client/TransportClientFactory.java  |  64 +++++--
 .../spark/network/server/NoOpRpcHandler.java    |   2 +-
 .../apache/spark/network/server/RpcHandler.java |  19 +-
 .../network/server/TransportRequestHandler.java |   1 +
 .../spark/network/util/TransportConf.java       |   3 +
 .../spark/network/sasl/SaslClientBootstrap.java |  74 ++++++++
 .../apache/spark/network/sasl/SaslMessage.java  |  74 ++++++++
 .../spark/network/sasl/SaslRpcHandler.java      |  97 ++++++++++
 .../spark/network/sasl/SecretKeyHolder.java     |  35 ++++
 .../spark/network/sasl/SparkSaslClient.java     | 138 +++++++++++++++
 .../spark/network/sasl/SparkSaslServer.java     | 170 ++++++++++++++++++
 .../shuffle/ExternalShuffleBlockHandler.java    |   2 +-
 .../network/shuffle/ExternalShuffleClient.java  |  15 +-
 .../spark/network/shuffle/ShuffleClient.java    |  11 +-
 .../network/sasl/SaslIntegrationSuite.java      | 172 ++++++++++++++++++
 .../spark/network/sasl/SparkSaslSuite.java      |  89 ++++++++++
 .../ExternalShuffleIntegrationSuite.java        |   7 +-
 .../streaming/ReceivedBlockHandlerSuite.scala   |   1 +
 37 files changed, 1257 insertions(+), 392 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/23643403/core/src/main/scala/org/apache/spark/SecurityManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala
index 0e0f1a7..dee935f 100644
--- a/core/src/main/scala/org/apache/spark/SecurityManager.scala
+++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -22,6 +22,7 @@ import java.net.{Authenticator, PasswordAuthentication}
 import org.apache.hadoop.io.Text
 
 import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.network.sasl.SecretKeyHolder
 
 /**
  * Spark class responsible for security.
@@ -84,7 +85,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
  *            Authenticator installed in the SecurityManager to how it does the authentication
  *            and in this case gets the user name and password from the request.
  *
- *  - ConnectionManager -> The Spark ConnectionManager uses java nio to asynchronously
+ *  - BlockTransferService -> The Spark BlockTransferServices use java nio to asynchronously
  *            exchange messages.  For this we use the Java SASL
  *            (Simple Authentication and Security Layer) API and again use DIGEST-MD5
  *            as the authentication mechanism. This means the shared secret is not passed
@@ -98,7 +99,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
  *            of protection they want. If we support those, the messages will also have to
  *            be wrapped and unwrapped via the SaslServer/SaslClient.wrap/unwrap API's.
  *
- *            Since the connectionManager does asynchronous messages passing, the SASL
+ *            Since the NioBlockTransferService does asynchronous message passing, the SASL
  *            authentication is a bit more complex. A ConnectionManager can be both a client
  *            and a Server, so for a particular connection is has to determine what to do.
  *            A ConnectionId was added to be able to track connections and is used to
@@ -107,6 +108,10 @@ import org.apache.spark.deploy.SparkHadoopUtil
  *            and waits for the response from the server and does the handshake before sending
  *            the real message.
  *
+ *            The NettyBlockTransferService ensures that SASL authentication is performed
+ *            synchronously prior to any other communication on a connection. This is done in
+ *            SaslClientBootstrap on the client side and SaslRpcHandler on the server side.
+ *
  *  - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters
  *            can be used. Yarn requires a specific AmIpFilter be installed for security to work
  *            properly. For non-Yarn deployments, users can write a filter to go through a
@@ -139,7 +144,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
  *  can take place.
  */
 
-private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging {
+private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging with SecretKeyHolder {
 
   // key used to store the spark secret in the Hadoop UGI
   private val sparkSecretLookupKey = "sparkCookie"
@@ -337,4 +342,16 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging {
    * @return the secret key as a String if authentication is enabled, otherwise returns null
    */
   def getSecretKey(): String = secretKey
+
+  override def getSaslUser(appId: String): String = {
+    val myAppId = sparkConf.getAppId
+    require(appId == myAppId, s"SASL appId $appId did not match my appId ${myAppId}")
+    getSaslUser()
+  }
+
+  override def getSecretKey(appId: String): String = {
+    val myAppId = sparkConf.getAppId
+    require(appId == myAppId, s"SASL appId $appId did not match my appId ${myAppId}")
+    getSecretKey()
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/23643403/core/src/main/scala/org/apache/spark/SparkConf.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index ad0a901..4c6c86c 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -217,6 +217,12 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
      */
     getAll.filter { case (k, _) => isAkkaConf(k) }
 
+  /**
+   * Returns the Spark application id, valid in the Driver after TaskScheduler registration and
+   * from the start in the Executor.
+   */
+  def getAppId: String = get("spark.app.id")
+
   /** Does the configuration contain a given parameter? */
   def contains(key: String): Boolean = settings.contains(key)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/23643403/core/src/main/scala/org/apache/spark/SparkContext.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 8b4db78..d65027d 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -313,6 +313,8 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
   val applicationId: String = taskScheduler.applicationId()
   conf.set("spark.app.id", applicationId)
 
+  env.blockManager.initialize(applicationId)
+
   val metricsSystem = env.metricsSystem
 
   // The metrics system for Driver need to be set spark.app.id to app ID.

http://git-wip-us.apache.org/repos/asf/spark/blob/23643403/core/src/main/scala/org/apache/spark/SparkEnv.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index e2f13ac..45e9d7f 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -276,7 +276,7 @@ object SparkEnv extends Logging {
     val blockTransferService =
       conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match {
         case "netty" =>
-          new NettyBlockTransferService(conf)
+          new NettyBlockTransferService(conf, securityManager)
         case "nio" =>
           new NioBlockTransferService(conf, securityManager)
       }
@@ -285,6 +285,7 @@ object SparkEnv extends Logging {
       "BlockManagerMaster",
       new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf, isDriver)
 
+    // NB: blockManager is not valid until initialize() is called later.
     val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster,
       serializer, conf, mapOutputTracker, shuffleManager, blockTransferService)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/23643403/core/src/main/scala/org/apache/spark/SparkSaslClient.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala
deleted file mode 100644
index a954fcc..0000000
--- a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala
+++ /dev/null
@@ -1,147 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark
-
-import javax.security.auth.callback.Callback
-import javax.security.auth.callback.CallbackHandler
-import javax.security.auth.callback.NameCallback
-import javax.security.auth.callback.PasswordCallback
-import javax.security.auth.callback.UnsupportedCallbackException
-import javax.security.sasl.RealmCallback
-import javax.security.sasl.RealmChoiceCallback
-import javax.security.sasl.Sasl
-import javax.security.sasl.SaslClient
-import javax.security.sasl.SaslException
-
-import scala.collection.JavaConversions.mapAsJavaMap
-
-import com.google.common.base.Charsets.UTF_8
-
-/**
- * Implements SASL Client logic for Spark
- */
-private[spark] class SparkSaslClient(securityMgr: SecurityManager)  extends Logging {
-
-  /**
-   * Used to respond to server's counterpart, SaslServer with SASL tokens
-   * represented as byte arrays.
-   *
-   * The authentication mechanism used here is DIGEST-MD5. This could be changed to be
-   * configurable in the future.
-   */
-  private var saslClient: SaslClient = Sasl.createSaslClient(Array[String](SparkSaslServer.DIGEST),
-    null, null, SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS,
-    new SparkSaslClientCallbackHandler(securityMgr))
-
-  /**
-   * Used to initiate SASL handshake with server.
-   * @return response to challenge if needed
-   */
-  def firstToken(): Array[Byte] = {
-    synchronized {
-      val saslToken: Array[Byte] =
-        if (saslClient != null && saslClient.hasInitialResponse()) {
-          logDebug("has initial response")
-          saslClient.evaluateChallenge(new Array[Byte](0))
-        } else {
-          new Array[Byte](0)
-        }
-      saslToken
-    }
-  }
-
-  /**
-   * Determines whether the authentication exchange has completed.
-   * @return true is complete, otherwise false
-   */
-  def isComplete(): Boolean = {
-    synchronized {
-      if (saslClient != null) saslClient.isComplete() else false
-    }
-  }
-
-  /**
-   * Respond to server's SASL token.
-   * @param saslTokenMessage contains server's SASL token
-   * @return client's response SASL token
-   */
-  def saslResponse(saslTokenMessage: Array[Byte]): Array[Byte] = {
-    synchronized {
-      if (saslClient != null) saslClient.evaluateChallenge(saslTokenMessage) else new Array[Byte](0)
-    }
-  }
-
-  /**
-   * Disposes of any system resources or security-sensitive information the
-   * SaslClient might be using.
-   */
-  def dispose() {
-    synchronized {
-      if (saslClient != null) {
-        try {
-          saslClient.dispose()
-        } catch {
-          case e: SaslException => // ignored
-        } finally {
-          saslClient = null
-        }
-      }
-    }
-  }
-
-  /**
-   * Implementation of javax.security.auth.callback.CallbackHandler
-   * that works with share secrets.
-   */
-  private class SparkSaslClientCallbackHandler(securityMgr: SecurityManager) extends
-    CallbackHandler {
-
-    private val userName: String =
-      SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes(UTF_8))
-    private val secretKey = securityMgr.getSecretKey()
-    private val userPassword: Array[Char] = SparkSaslServer.encodePassword(
-        if (secretKey != null) secretKey.getBytes(UTF_8) else "".getBytes(UTF_8))
-
-    /**
-     * Implementation used to respond to SASL request from the server.
-     *
-     * @param callbacks objects that indicate what credential information the
-     *                  server's SaslServer requires from the client.
-     */
-    override def handle(callbacks: Array[Callback]) {
-      logDebug("in the sasl client callback handler")
-      callbacks foreach {
-        case  nc: NameCallback => {
-          logDebug("handle: SASL client callback: setting username: " + userName)
-          nc.setName(userName)
-        }
-        case pc: PasswordCallback => {
-          logDebug("handle: SASL client callback: setting userPassword")
-          pc.setPassword(userPassword)
-        }
-        case rc: RealmCallback => {
-          logDebug("handle: SASL client callback: setting realm: " + rc.getDefaultText())
-          rc.setText(rc.getDefaultText())
-        }
-        case cb: RealmChoiceCallback => {}
-        case cb: Callback => throw
-          new UnsupportedCallbackException(cb, "handle: Unrecognized SASL client callback")
-      }
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/23643403/core/src/main/scala/org/apache/spark/SparkSaslServer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala
deleted file mode 100644
index 7c2afb3..0000000
--- a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala
+++ /dev/null
@@ -1,176 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark
-
-import javax.security.auth.callback.Callback
-import javax.security.auth.callback.CallbackHandler
-import javax.security.auth.callback.NameCallback
-import javax.security.auth.callback.PasswordCallback
-import javax.security.auth.callback.UnsupportedCallbackException
-import javax.security.sasl.AuthorizeCallback
-import javax.security.sasl.RealmCallback
-import javax.security.sasl.Sasl
-import javax.security.sasl.SaslException
-import javax.security.sasl.SaslServer
-import scala.collection.JavaConversions.mapAsJavaMap
-
-import com.google.common.base.Charsets.UTF_8
-import org.apache.commons.net.util.Base64
-
-/**
- * Encapsulates SASL server logic
- */
-private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Logging {
-
-  /**
-   * Actual SASL work done by this object from javax.security.sasl.
-   */
-  private var saslServer: SaslServer = Sasl.createSaslServer(SparkSaslServer.DIGEST, null,
-    SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS,
-    new SparkSaslDigestCallbackHandler(securityMgr))
-
-  /**
-   * Determines whether the authentication exchange has completed.
-   * @return true is complete, otherwise false
-   */
-  def isComplete(): Boolean = {
-    synchronized {
-      if (saslServer != null) saslServer.isComplete() else false
-    }
-  }
-
-  /**
-   * Used to respond to server SASL tokens.
-   * @param token Server's SASL token
-   * @return response to send back to the server.
-   */
-  def response(token: Array[Byte]): Array[Byte] = {
-    synchronized {
-      if (saslServer != null) saslServer.evaluateResponse(token) else new Array[Byte](0)
-    }
-  }
-
-  /**
-   * Disposes of any system resources or security-sensitive information the
-   * SaslServer might be using.
-   */
-  def dispose() {
-    synchronized {
-      if (saslServer != null) {
-        try {
-          saslServer.dispose()
-        } catch {
-          case e: SaslException => // ignore
-        } finally {
-          saslServer = null
-        }
-      }
-    }
-  }
-
-  /**
-   * Implementation of javax.security.auth.callback.CallbackHandler
-   * for SASL DIGEST-MD5 mechanism
-   */
-  private class SparkSaslDigestCallbackHandler(securityMgr: SecurityManager)
-    extends CallbackHandler {
-
-    private val userName: String =
-      SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes(UTF_8))
-
-    override def handle(callbacks: Array[Callback]) {
-      logDebug("In the sasl server callback handler")
-      callbacks foreach {
-        case nc: NameCallback => {
-          logDebug("handle: SASL server callback: setting username")
-          nc.setName(userName)
-        }
-        case pc: PasswordCallback => {
-          logDebug("handle: SASL server callback: setting userPassword")
-          val password: Array[Char] =
-            SparkSaslServer.encodePassword(securityMgr.getSecretKey().getBytes(UTF_8))
-          pc.setPassword(password)
-        }
-        case rc: RealmCallback => {
-          logDebug("handle: SASL server callback: setting realm: " + rc.getDefaultText())
-          rc.setText(rc.getDefaultText())
-        }
-        case ac: AuthorizeCallback => {
-          val authid = ac.getAuthenticationID()
-          val authzid = ac.getAuthorizationID()
-          if (authid.equals(authzid)) {
-            logDebug("set auth to true")
-            ac.setAuthorized(true)
-          } else {
-            logDebug("set auth to false")
-            ac.setAuthorized(false)
-          }
-          if (ac.isAuthorized()) {
-            logDebug("sasl server is authorized")
-            ac.setAuthorizedID(authzid)
-          }
-        }
-        case cb: Callback => throw
-          new UnsupportedCallbackException(cb, "handle: Unrecognized SASL DIGEST-MD5 Callback")
-      }
-    }
-  }
-}
-
-private[spark] object SparkSaslServer {
-
-  /**
-   * This is passed as the server name when creating the sasl client/server.
-   * This could be changed to be configurable in the future.
-   */
-  val  SASL_DEFAULT_REALM = "default"
-
-  /**
-   * The authentication mechanism used here is DIGEST-MD5. This could be changed to be
-   * configurable in the future.
-   */
-  val DIGEST = "DIGEST-MD5"
-
-  /**
-   * The quality of protection is just "auth". This means that we are doing
-   * authentication only, we are not supporting integrity or privacy protection of the
-   * communication channel after authentication. This could be changed to be configurable
-   * in the future.
-   */
-  val SASL_PROPS = Map(Sasl.QOP -> "auth", Sasl.SERVER_AUTH ->"true")
-
-  /**
-   * Encode a byte[] identifier as a Base64-encoded string.
-   *
-   * @param identifier identifier to encode
-   * @return Base64-encoded string
-   */
-  def encodeIdentifier(identifier: Array[Byte]): String = {
-    new String(Base64.encodeBase64(identifier), UTF_8)
-  }
-
-  /**
-   * Encode a password as a base64-encoded char[] array.
-   * @param password as a byte array.
-   * @return password as a char array.
-   */
-  def encodePassword(password: Array[Byte]): Array[Char] = {
-    new String(Base64.encodeBase64(password), UTF_8).toCharArray()
-  }
-}
-

http://git-wip-us.apache.org/repos/asf/spark/blob/23643403/core/src/main/scala/org/apache/spark/executor/Executor.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index e24a15f..7dd5265 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -86,6 +86,7 @@ private[spark] class Executor(
         conf, executorId, slaveHostname, port, isLocal, actorSystem)
       SparkEnv.set(_env)
       _env.metricsSystem.registerSource(executorSource)
+      _env.blockManager.initialize(conf.getAppId)
       _env
     } else {
       SparkEnv.get

http://git-wip-us.apache.org/repos/asf/spark/blob/23643403/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
index 1c4327c..0d1fc81 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -17,13 +17,15 @@
 
 package org.apache.spark.network.netty
 
+import scala.collection.JavaConversions._
 import scala.concurrent.{Future, Promise}
 
-import org.apache.spark.SparkConf
+import org.apache.spark.{SecurityManager, SparkConf}
 import org.apache.spark.network._
 import org.apache.spark.network.buffer.ManagedBuffer
-import org.apache.spark.network.client.{RpcResponseCallback, TransportClientFactory}
+import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCallback, TransportClientFactory}
 import org.apache.spark.network.netty.NettyMessages.{OpenBlocks, UploadBlock}
+import org.apache.spark.network.sasl.{SaslRpcHandler, SaslClientBootstrap}
 import org.apache.spark.network.server._
 import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher}
 import org.apache.spark.serializer.JavaSerializer
@@ -33,18 +35,30 @@ import org.apache.spark.util.Utils
 /**
  * A BlockTransferService that uses Netty to fetch a set of blocks at at time.
  */
-class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService {
+class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManager)
+  extends BlockTransferService {
+
   // TODO: Don't use Java serialization, use a more cross-version compatible serialization format.
-  val serializer = new JavaSerializer(conf)
+  private val serializer = new JavaSerializer(conf)
+  private val authEnabled = securityManager.isAuthenticationEnabled()
+  private val transportConf = SparkTransportConf.fromSparkConf(conf)
 
   private[this] var transportContext: TransportContext = _
   private[this] var server: TransportServer = _
   private[this] var clientFactory: TransportClientFactory = _
 
   override def init(blockDataManager: BlockDataManager): Unit = {
-    val rpcHandler = new NettyBlockRpcServer(serializer, blockDataManager)
-    transportContext = new TransportContext(SparkTransportConf.fromSparkConf(conf), rpcHandler)
-    clientFactory = transportContext.createClientFactory()
+    val (rpcHandler: RpcHandler, bootstrap: Option[TransportClientBootstrap]) = {
+      val nettyRpcHandler = new NettyBlockRpcServer(serializer, blockDataManager)
+      if (!authEnabled) {
+        (nettyRpcHandler, None)
+      } else {
+        (new SaslRpcHandler(nettyRpcHandler, securityManager),
+          Some(new SaslClientBootstrap(transportConf, conf.getAppId, securityManager)))
+      }
+    }
+    transportContext = new TransportContext(transportConf, rpcHandler)
+    clientFactory = transportContext.createClientFactory(bootstrap.toList)
     server = transportContext.createServer()
     logInfo("Server created on " + server.getPort)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/23643403/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
index 4f6f5e2..c2d9578 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
@@ -23,12 +23,13 @@ import java.nio.channels._
 import java.util.concurrent.ConcurrentLinkedQueue
 import java.util.LinkedList
 
-import org.apache.spark._
-
 import scala.collection.JavaConversions._
 import scala.collection.mutable.{ArrayBuffer, HashMap}
 import scala.util.control.NonFatal
 
+import org.apache.spark._
+import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer}
+
 private[nio]
 abstract class Connection(val channel: SocketChannel, val selector: Selector,
     val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId,

http://git-wip-us.apache.org/repos/asf/spark/blob/23643403/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
index 8408b75..f198aa8 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
@@ -34,6 +34,7 @@ import scala.language.postfixOps
 import com.google.common.base.Charsets.UTF_8
 
 import org.apache.spark._
+import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer}
 import org.apache.spark.util.Utils
 
 import scala.util.Try
@@ -600,7 +601,7 @@ private[nio] class ConnectionManager(
     } else {
       var replyToken : Array[Byte] = null
       try {
-        replyToken = waitingConn.sparkSaslClient.saslResponse(securityMsg.getToken)
+        replyToken = waitingConn.sparkSaslClient.response(securityMsg.getToken)
         if (waitingConn.isSaslComplete()) {
           logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId)
           connectionsAwaitingSasl -= waitingConn.connectionId
@@ -634,7 +635,7 @@ private[nio] class ConnectionManager(
         connection.synchronized {
           if (connection.sparkSaslServer == null) {
             logDebug("Creating sasl Server")
-            connection.sparkSaslServer = new SparkSaslServer(securityManager)
+            connection.sparkSaslServer = new SparkSaslServer(conf.getAppId, securityManager)
           }
         }
         replyToken = connection.sparkSaslServer.response(securityMsg.getToken)
@@ -778,7 +779,7 @@ private[nio] class ConnectionManager(
     if (!conn.isSaslComplete()) {
       conn.synchronized {
         if (conn.sparkSaslClient == null) {
-          conn.sparkSaslClient = new SparkSaslClient(securityManager)
+          conn.sparkSaslClient = new SparkSaslClient(conf.getAppId, securityManager)
           var firstResponse: Array[Byte] = null
           try {
             firstResponse = conn.sparkSaslClient.firstToken()

http://git-wip-us.apache.org/repos/asf/spark/blob/23643403/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 5f5dd0d..655d16c 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -57,6 +57,12 @@ private[spark] class BlockResult(
   inputMetrics.bytesRead = bytes
 }
 
+/**
+ * Manager running on every node (driver and executors) which provides interfaces for putting and
+ * retrieving blocks both locally and remotely into various stores (memory, disk, and off-heap).
+ *
+ * Note that #initialize() must be called before the BlockManager is usable.
+ */
 private[spark] class BlockManager(
     executorId: String,
     actorSystem: ActorSystem,
@@ -69,8 +75,6 @@ private[spark] class BlockManager(
     blockTransferService: BlockTransferService)
   extends BlockDataManager with Logging {
 
-  blockTransferService.init(this)
-
   val diskBlockManager = new DiskBlockManager(this, conf)
 
   private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]
@@ -102,22 +106,16 @@ private[spark] class BlockManager(
       + " switch to sort-based shuffle.")
   }
 
-  val blockManagerId = BlockManagerId(
-    executorId, blockTransferService.hostName, blockTransferService.port)
+  var blockManagerId: BlockManagerId = _
 
   // Address of the server that serves this executor's shuffle files. This is either an external
   // service, or just our own Executor's BlockManager.
-  private[spark] val shuffleServerId = if (externalShuffleServiceEnabled) {
-    BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort)
-  } else {
-    blockManagerId
-  }
+  private[spark] var shuffleServerId: BlockManagerId = _
 
   // Client to read other executors' shuffle files. This is either an external service, or just the
   // standard BlockTranserService to directly connect to other Executors.
   private[spark] val shuffleClient = if (externalShuffleServiceEnabled) {
-    val appId = conf.get("spark.app.id", "unknown-app-id")
-    new ExternalShuffleClient(SparkTransportConf.fromSparkConf(conf), appId)
+    new ExternalShuffleClient(SparkTransportConf.fromSparkConf(conf))
   } else {
     blockTransferService
   }
@@ -150,8 +148,6 @@ private[spark] class BlockManager(
   private val peerFetchLock = new Object
   private var lastPeerFetchTime = 0L
 
-  initialize()
-
   /* The compression codec to use. Note that the "lazy" val is necessary because we want to delay
    * the initialization of the compression codec until it is first used. The reason is that a Spark
    * program could be using a user-defined codec in a third party jar, which is loaded in
@@ -176,10 +172,27 @@ private[spark] class BlockManager(
   }
 
   /**
-   * Initialize the BlockManager. Register to the BlockManagerMaster, and start the
-   * BlockManagerWorker actor. Additionally registers with a local shuffle service if configured.
+   * Initializes the BlockManager with the given appId. This is not performed in the constructor as
+   * the appId may not be known at BlockManager instantiation time (in particular for the driver,
+   * where it is only learned after registration with the TaskScheduler).
+   *
+   * This method initializes the BlockTransferService and ShuffleClient, registers with the
+   * BlockManagerMaster, starts the BlockManagerWorker actor, and registers with a local shuffle
+   * service if configured.
    */
-  private def initialize(): Unit = {
+  def initialize(appId: String): Unit = {
+    blockTransferService.init(this)
+    shuffleClient.init(appId)
+
+    blockManagerId = BlockManagerId(
+      executorId, blockTransferService.hostName, blockTransferService.port)
+
+    shuffleServerId = if (externalShuffleServiceEnabled) {
+      BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort)
+    } else {
+      blockManagerId
+    }
+
     master.registerBlockManager(blockManagerId, maxMemory, slaveActor)
 
     // Register Executors' configuration with the local shuffle service, if one should exist.

http://git-wip-us.apache.org/repos/asf/spark/blob/23643403/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
new file mode 100644
index 0000000..bed0ed9
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty
+
+import java.nio._
+import java.util.concurrent.TimeUnit
+
+import scala.concurrent.duration._
+import scala.concurrent.{Await, Promise}
+import scala.util.{Failure, Success, Try}
+
+import org.apache.commons.io.IOUtils
+import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.shuffle.BlockFetchingListener
+import org.apache.spark.network.{BlockDataManager, BlockTransferService}
+import org.apache.spark.storage.{BlockId, ShuffleBlockId}
+import org.apache.spark.{SecurityManager, SparkConf}
+import org.mockito.Mockito._
+import org.scalatest.mock.MockitoSugar
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite, ShouldMatchers}
+
+class NettyBlockTransferSecuritySuite extends FunSuite with MockitoSugar with ShouldMatchers {
+  test("security default off") {
+    testConnection(new SparkConf, new SparkConf) match {
+      case Success(_) => // expected
+      case Failure(t) => fail(t)
+    }
+  }
+
+  test("security on same password") {
+    val conf = new SparkConf()
+      .set("spark.authenticate", "true")
+      .set("spark.authenticate.secret", "good")
+      .set("spark.app.id", "app-id")
+    testConnection(conf, conf) match {
+      case Success(_) => // expected
+      case Failure(t) => fail(t)
+    }
+  }
+
+  test("security on mismatch password") {
+    val conf0 = new SparkConf()
+      .set("spark.authenticate", "true")
+      .set("spark.authenticate.secret", "good")
+      .set("spark.app.id", "app-id")
+    val conf1 = conf0.clone.set("spark.authenticate.secret", "bad")
+    testConnection(conf0, conf1) match {
+      case Success(_) => fail("Should have failed")
+      case Failure(t) => t.getMessage should include ("Mismatched response")
+    }
+  }
+
+  test("security mismatch auth off on server") {
+    val conf0 = new SparkConf()
+      .set("spark.authenticate", "true")
+      .set("spark.authenticate.secret", "good")
+      .set("spark.app.id", "app-id")
+    val conf1 = conf0.clone.set("spark.authenticate", "false")
+    testConnection(conf0, conf1) match {
+      case Success(_) => fail("Should have failed")
+      case Failure(t) => // any funny error may occur, sever will interpret SASL token as RPC
+    }
+  }
+
+  test("security mismatch auth off on client") {
+    val conf0 = new SparkConf()
+      .set("spark.authenticate", "false")
+      .set("spark.authenticate.secret", "good")
+      .set("spark.app.id", "app-id")
+    val conf1 = conf0.clone.set("spark.authenticate", "true")
+    testConnection(conf0, conf1) match {
+      case Success(_) => fail("Should have failed")
+      case Failure(t) => t.getMessage should include ("Expected SaslMessage")
+    }
+  }
+
+  test("security mismatch app ids") {
+    val conf0 = new SparkConf()
+      .set("spark.authenticate", "true")
+      .set("spark.authenticate.secret", "good")
+      .set("spark.app.id", "app-id")
+    val conf1 = conf0.clone.set("spark.app.id", "other-id")
+    testConnection(conf0, conf1) match {
+      case Success(_) => fail("Should have failed")
+      case Failure(t) => t.getMessage should include ("SASL appId app-id did not match")
+    }
+  }
+
+  /**
+   * Creates two servers with different configurations and sees if they can talk.
+   * Returns Success() if they can transfer a block, and Failure() if the block transfer was failed
+   * properly. We will throw an out-of-band exception if something other than that goes wrong.
+   */
+  private def testConnection(conf0: SparkConf, conf1: SparkConf): Try[Unit] = {
+    val blockManager = mock[BlockDataManager]
+    val blockId = ShuffleBlockId(0, 1, 2)
+    val blockString = "Hello, world!"
+    val blockBuffer = new NioManagedBuffer(ByteBuffer.wrap(blockString.getBytes))
+    when(blockManager.getBlockData(blockId)).thenReturn(blockBuffer)
+
+    val securityManager0 = new SecurityManager(conf0)
+    val exec0 = new NettyBlockTransferService(conf0, securityManager0)
+    exec0.init(blockManager)
+
+    val securityManager1 = new SecurityManager(conf1)
+    val exec1 = new NettyBlockTransferService(conf1, securityManager1)
+    exec1.init(blockManager)
+
+    val result = fetchBlock(exec0, exec1, "1", blockId) match {
+      case Success(buf) =>
+        IOUtils.toString(buf.createInputStream()) should equal(blockString)
+        buf.release()
+        Success()
+      case Failure(t) =>
+        Failure(t)
+    }
+    exec0.close()
+    exec1.close()
+    result
+  }
+
+  /** Synchronously fetches a single block, acting as the given executor fetching from another. */
+  private def fetchBlock(
+      self: BlockTransferService,
+      from: BlockTransferService,
+      execId: String,
+      blockId: BlockId): Try[ManagedBuffer] = {
+
+    val promise = Promise[ManagedBuffer]()
+
+    self.fetchBlocks(from.hostName, from.port, execId, Array(blockId.toString),
+      new BlockFetchingListener {
+        override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = {
+          promise.failure(exception)
+        }
+
+        override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
+          promise.success(data.retain())
+        }
+      })
+
+    Await.ready(promise.future, FiniteDuration(1000, TimeUnit.MILLISECONDS))
+    promise.future.value.get
+  }
+}
+

http://git-wip-us.apache.org/repos/asf/spark/blob/23643403/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala
index b70734d..716f875 100644
--- a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala
@@ -60,6 +60,7 @@ class ConnectionManagerSuite extends FunSuite {
     val conf = new SparkConf
     conf.set("spark.authenticate", "true")
     conf.set("spark.authenticate.secret", "good")
+    conf.set("spark.app.id", "app-id")
     val securityManager = new SecurityManager(conf)
     val manager = new ConnectionManager(0, conf, securityManager)
     var numReceivedMessages = 0
@@ -95,6 +96,7 @@ class ConnectionManagerSuite extends FunSuite {
   test("security mismatch password") {
     val conf = new SparkConf
     conf.set("spark.authenticate", "true")
+    conf.set("spark.app.id", "app-id")
     conf.set("spark.authenticate.secret", "good")
     val securityManager = new SecurityManager(conf)
     val manager = new ConnectionManager(0, conf, securityManager)
@@ -105,9 +107,7 @@ class ConnectionManagerSuite extends FunSuite {
       None
     })
 
-    val badconf = new SparkConf
-    badconf.set("spark.authenticate", "true")
-    badconf.set("spark.authenticate.secret", "bad")
+    val badconf = conf.clone.set("spark.authenticate.secret", "bad")
     val badsecurityManager = new SecurityManager(badconf)
     val managerServer = new ConnectionManager(0, badconf, badsecurityManager)
     var numReceivedServerMessages = 0

http://git-wip-us.apache.org/repos/asf/spark/blob/23643403/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
index c6d7105..1461fa6 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
@@ -63,6 +63,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd
     val transfer = new NioBlockTransferService(conf, securityMgr)
     val store = new BlockManager(name, actorSystem, master, serializer, maxMem, conf,
       mapOutputTracker, shuffleManager, transfer)
+    store.initialize("app-id")
     allStores += store
     store
   }
@@ -263,6 +264,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd
     when(failableTransfer.port).thenReturn(1000)
     val failableStore = new BlockManager("failable-store", actorSystem, master, serializer,
       10000, conf, mapOutputTracker, shuffleManager, failableTransfer)
+    failableStore.initialize("app-id")
     allStores += failableStore // so that this gets stopped after test
     assert(master.getPeers(store.blockManagerId).toSet === Set(failableStore.blockManagerId))
 

http://git-wip-us.apache.org/repos/asf/spark/blob/23643403/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 715b740..0782876 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -73,8 +73,10 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
       maxMem: Long,
       name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = {
     val transfer = new NioBlockTransferService(conf, securityMgr)
-    new BlockManager(name, actorSystem, master, serializer, maxMem, conf,
+    val manager = new BlockManager(name, actorSystem, master, serializer, maxMem, conf,
       mapOutputTracker, shuffleManager, transfer)
+    manager.initialize("app-id")
+    manager
   }
 
   before {

http://git-wip-us.apache.org/repos/asf/spark/blob/23643403/docs/security.md
----------------------------------------------------------------------
diff --git a/docs/security.md b/docs/security.md
index ec05231..1e206a1 100644
--- a/docs/security.md
+++ b/docs/security.md
@@ -7,7 +7,6 @@ Spark currently supports authentication via a shared secret. Authentication can
 
 * For Spark on [YARN](running-on-yarn.html) deployments, configuring `spark.authenticate` to `true` will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret. 
 * For other types of Spark deployments, the Spark parameter `spark.authenticate.secret` should be configured on each of the nodes. This secret will be used by all the Master/Workers and applications.
-* **IMPORTANT NOTE:** *The experimental Netty shuffle path (`spark.shuffle.use.netty`) is not secured, so do not use Netty for shuffles if running with authentication.*
 
 ## Web UI
 

http://git-wip-us.apache.org/repos/asf/spark/blob/23643403/network/common/src/main/java/org/apache/spark/network/TransportContext.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
index a271841..5bc6e5a 100644
--- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java
+++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
@@ -17,12 +17,16 @@
 
 package org.apache.spark.network;
 
+import java.util.List;
+
+import com.google.common.collect.Lists;
 import io.netty.channel.Channel;
 import io.netty.channel.socket.SocketChannel;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientBootstrap;
 import org.apache.spark.network.client.TransportClientFactory;
 import org.apache.spark.network.client.TransportResponseHandler;
 import org.apache.spark.network.protocol.MessageDecoder;
@@ -64,8 +68,17 @@ public class TransportContext {
     this.decoder = new MessageDecoder();
   }
 
+  /**
+   * Initializes a ClientFactory which runs the given TransportClientBootstraps prior to returning
+   * a new Client. Bootstraps will be executed synchronously, and must run successfully in order
+   * to create a Client.
+   */
+  public TransportClientFactory createClientFactory(List<TransportClientBootstrap> bootstraps) {
+    return new TransportClientFactory(this, bootstraps);
+  }
+
   public TransportClientFactory createClientFactory() {
-    return new TransportClientFactory(this);
+    return createClientFactory(Lists.<TransportClientBootstrap>newArrayList());
   }
 
   /** Create a server which will attempt to bind to a specific port. */

http://git-wip-us.apache.org/repos/asf/spark/blob/23643403/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
index 01c143f..a08cee0 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -19,10 +19,9 @@ package org.apache.spark.network.client;
 
 import java.io.Closeable;
 import java.util.UUID;
-import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
-import java.util.concurrent.TimeoutException;
 
+import com.google.common.base.Objects;
 import com.google.common.base.Preconditions;
 import com.google.common.base.Throwables;
 import com.google.common.util.concurrent.SettableFuture;
@@ -186,4 +185,12 @@ public class TransportClient implements Closeable {
     // close is a local operation and should finish with milliseconds; timeout just to be safe
     channel.close().awaitUninterruptibly(10, TimeUnit.SECONDS);
   }
+
+  @Override
+  public String toString() {
+    return Objects.toStringHelper(this)
+      .add("remoteAdress", channel.remoteAddress())
+      .add("isActive", isActive())
+      .toString();
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/23643403/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java
new file mode 100644
index 0000000..65e8020
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+/**
+ * A bootstrap which is executed on a TransportClient before it is returned to the user.
+ * This enables an initial exchange of information (e.g., SASL authentication tokens) on a once-per-
+ * connection basis.
+ *
+ * Since connections (and TransportClients) are reused as much as possible, it is generally
+ * reasonable to perform an expensive bootstrapping operation, as they often share a lifespan with
+ * the JVM itself.
+ */
+public interface TransportClientBootstrap {
+  /** Performs the bootstrapping operation, throwing an exception on failure. */
+  public void doBootstrap(TransportClient client) throws RuntimeException;
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/23643403/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
index 0b4a1d8..1723fed 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
@@ -21,10 +21,14 @@ import java.io.Closeable;
 import java.lang.reflect.Field;
 import java.net.InetSocketAddress;
 import java.net.SocketAddress;
+import java.util.List;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.AtomicReference;
 
+import com.google.common.base.Preconditions;
+import com.google.common.base.Throwables;
+import com.google.common.collect.Lists;
 import io.netty.bootstrap.Bootstrap;
 import io.netty.buffer.PooledByteBufAllocator;
 import io.netty.channel.Channel;
@@ -40,6 +44,7 @@ import org.slf4j.LoggerFactory;
 import org.apache.spark.network.TransportContext;
 import org.apache.spark.network.server.TransportChannelHandler;
 import org.apache.spark.network.util.IOMode;
+import org.apache.spark.network.util.JavaUtils;
 import org.apache.spark.network.util.NettyUtils;
 import org.apache.spark.network.util.TransportConf;
 
@@ -47,22 +52,29 @@ import org.apache.spark.network.util.TransportConf;
  * Factory for creating {@link TransportClient}s by using createClient.
  *
  * The factory maintains a connection pool to other hosts and should return the same
- * {@link TransportClient} for the same remote host. It also shares a single worker thread pool for
- * all {@link TransportClient}s.
+ * TransportClient for the same remote host. It also shares a single worker thread pool for
+ * all TransportClients.
+ *
+ * TransportClients will be reused whenever possible. Prior to completing the creation of a new
+ * TransportClient, all given {@link TransportClientBootstrap}s will be run.
  */
 public class TransportClientFactory implements Closeable {
   private final Logger logger = LoggerFactory.getLogger(TransportClientFactory.class);
 
   private final TransportContext context;
   private final TransportConf conf;
+  private final List<TransportClientBootstrap> clientBootstraps;
   private final ConcurrentHashMap<SocketAddress, TransportClient> connectionPool;
 
   private final Class<? extends Channel> socketChannelClass;
   private EventLoopGroup workerGroup;
 
-  public TransportClientFactory(TransportContext context) {
-    this.context = context;
+  public TransportClientFactory(
+      TransportContext context,
+      List<TransportClientBootstrap> clientBootstraps) {
+    this.context = Preconditions.checkNotNull(context);
     this.conf = context.getConf();
+    this.clientBootstraps = Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps));
     this.connectionPool = new ConcurrentHashMap<SocketAddress, TransportClient>();
 
     IOMode ioMode = IOMode.valueOf(conf.ioMode());
@@ -72,9 +84,12 @@ public class TransportClientFactory implements Closeable {
   }
 
   /**
-   * Create a new BlockFetchingClient connecting to the given remote host / port.
+   * Create a new {@link TransportClient} connecting to the given remote host / port. This will
+   * reuse TransportClients if they are still active and are for the same remote address. Prior
+   * to the creation of a new TransportClient, we will execute all {@link TransportClientBootstrap}s
+   * that are registered with this factory.
    *
-   * This blocks until a connection is successfully established.
+   * This blocks until a connection is successfully established and fully bootstrapped.
    *
    * Concurrency: This method is safe to call from multiple threads.
    */
@@ -104,17 +119,18 @@ public class TransportClientFactory implements Closeable {
     // Use pooled buffers to reduce temporary buffer allocation
     bootstrap.option(ChannelOption.ALLOCATOR, createPooledByteBufAllocator());
 
-    final AtomicReference<TransportClient> client = new AtomicReference<TransportClient>();
+    final AtomicReference<TransportClient> clientRef = new AtomicReference<TransportClient>();
 
     bootstrap.handler(new ChannelInitializer<SocketChannel>() {
       @Override
       public void initChannel(SocketChannel ch) {
         TransportChannelHandler clientHandler = context.initializePipeline(ch);
-        client.set(clientHandler.getClient());
+        clientRef.set(clientHandler.getClient());
       }
     });
 
     // Connect to the remote server
+    long preConnect = System.currentTimeMillis();
     ChannelFuture cf = bootstrap.connect(address);
     if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) {
       throw new RuntimeException(
@@ -123,15 +139,35 @@ public class TransportClientFactory implements Closeable {
       throw new RuntimeException(String.format("Failed to connect to %s", address), cf.cause());
     }
 
-    // Successful connection -- in the event that two threads raced to create a client, we will
+    TransportClient client = clientRef.get();
+    assert client != null : "Channel future completed successfully with null client";
+
+    // Execute any client bootstraps synchronously before marking the Client as successful.
+    long preBootstrap = System.currentTimeMillis();
+    logger.debug("Connection to {} successful, running bootstraps...", address);
+    try {
+      for (TransportClientBootstrap clientBootstrap : clientBootstraps) {
+        clientBootstrap.doBootstrap(client);
+      }
+    } catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala
+      long bootstrapTime = System.currentTimeMillis() - preBootstrap;
+      logger.error("Exception while bootstrapping client after " + bootstrapTime + " ms", e);
+      client.close();
+      throw Throwables.propagate(e);
+    }
+    long postBootstrap = System.currentTimeMillis();
+
+    // Successful connection & bootstrap -- in the event that two threads raced to create a client,
     // use the first one that was put into the connectionPool and close the one we made here.
-    assert client.get() != null : "Channel future completed successfully with null client";
-    TransportClient oldClient = connectionPool.putIfAbsent(address, client.get());
+    TransportClient oldClient = connectionPool.putIfAbsent(address, client);
     if (oldClient == null) {
-      return client.get();
+      logger.debug("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)",
+        address, postBootstrap - preConnect, postBootstrap - preBootstrap);
+      return client;
     } else {
-      logger.debug("Two clients were created concurrently, second one will be disposed.");
-      client.get().close();
+      logger.debug("Two clients were created concurrently after {} ms, second will be disposed.",
+        postBootstrap - preConnect);
+      client.close();
       return oldClient;
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/23643403/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java
index 5a3f003..1502b74 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java
@@ -21,7 +21,7 @@ import org.apache.spark.network.client.RpcResponseCallback;
 import org.apache.spark.network.client.TransportClient;
 
 /** An RpcHandler suitable for a client-only TransportContext, which cannot receive RPCs. */
-public class NoOpRpcHandler implements RpcHandler {
+public class NoOpRpcHandler extends RpcHandler {
   private final StreamManager streamManager;
 
   public NoOpRpcHandler() {

http://git-wip-us.apache.org/repos/asf/spark/blob/23643403/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
index 2369dc6..2ba92a4 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
@@ -23,22 +23,33 @@ import org.apache.spark.network.client.TransportClient;
 /**
  * Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.TransportClient}s.
  */
-public interface RpcHandler {
+public abstract class RpcHandler {
   /**
    * Receive a single RPC message. Any exception thrown while in this method will be sent back to
    * the client in string form as a standard RPC failure.
    *
+   * This method will not be called in parallel for a single TransportClient (i.e., channel).
+   *
    * @param client A channel client which enables the handler to make requests back to the sender
-   *               of this RPC.
+   *               of this RPC. This will always be the exact same object for a particular channel.
    * @param message The serialized bytes of the RPC.
    * @param callback Callback which should be invoked exactly once upon success or failure of the
    *                 RPC.
    */
-  void receive(TransportClient client, byte[] message, RpcResponseCallback callback);
+  public abstract void receive(
+      TransportClient client,
+      byte[] message,
+      RpcResponseCallback callback);
 
   /**
    * Returns the StreamManager which contains the state about which streams are currently being
    * fetched by a TransportClient.
    */
-  StreamManager getStreamManager();
+  public abstract StreamManager getStreamManager();
+
+  /**
+   * Invoked when the connection associated with the given client has been invalidated.
+   * No further requests will come from this client.
+   */
+  public void connectionTerminated(TransportClient client) { }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/23643403/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
index 17fe900..1580180 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
@@ -86,6 +86,7 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
     for (long streamId : streamIds) {
       streamManager.connectionTerminated(streamId);
     }
+    rpcHandler.connectionTerminated(reverseClient);
   }
 
   @Override

http://git-wip-us.apache.org/repos/asf/spark/blob/23643403/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
index a68f38e..823790d 100644
--- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -55,4 +55,7 @@ public class TransportConf {
 
   /** Send buffer size (SO_SNDBUF). */
   public int sendBuf() { return conf.getInt("spark.shuffle.io.sendBuffer", -1); }
+
+  /** Timeout for a single round trip of SASL token exchange, in milliseconds. */
+  public int saslRTTimeout() { return conf.getInt("spark.shuffle.sasl.timeout", 30000); }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/23643403/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
----------------------------------------------------------------------
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
new file mode 100644
index 0000000..7bc91e3
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientBootstrap;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * Bootstraps a {@link TransportClient} by performing SASL authentication on the connection. The
+ * server should be setup with a {@link SaslRpcHandler} with matching keys for the given appId.
+ */
+public class SaslClientBootstrap implements TransportClientBootstrap {
+  private final Logger logger = LoggerFactory.getLogger(SaslClientBootstrap.class);
+
+  private final TransportConf conf;
+  private final String appId;
+  private final SecretKeyHolder secretKeyHolder;
+
+  public SaslClientBootstrap(TransportConf conf, String appId, SecretKeyHolder secretKeyHolder) {
+    this.conf = conf;
+    this.appId = appId;
+    this.secretKeyHolder = secretKeyHolder;
+  }
+
+  /**
+   * Runs the SASL handshake over the given client. The SASL client's initial token is sent
+   * first, and every reply from the server is fed back into the SASL client until it reports
+   * that the exchange is complete. Exceptions from the SASL client or from the RPC call
+   * propagate to the caller; the SASL client is disposed of in every case.
+   */
+  @Override
+  public void doBootstrap(TransportClient client) {
+    SparkSaslClient sasl = new SparkSaslClient(appId, secretKeyHolder);
+    try {
+      byte[] outgoing = sasl.firstToken();
+      while (!sasl.isComplete()) {
+        byte[] request = encode(new SaslMessage(appId, outgoing));
+        byte[] serverReply = client.sendRpcSync(request, conf.saslRTTimeout());
+        outgoing = sasl.response(serverReply);
+      }
+    } finally {
+      disposeQuietly(sasl);
+    }
+  }
+
+  /** Writes the message into a heap buffer allocated with its encoded length; returns the array. */
+  private static byte[] encode(SaslMessage message) {
+    ByteBuf buf = Unpooled.buffer(message.encodedLength());
+    message.encode(buf);
+    return buf.array();
+  }
+
+  /** Disposes of the SASL client, logging (rather than rethrowing) any runtime failure. */
+  private void disposeQuietly(SparkSaslClient sasl) {
+    try {
+      // After a successful handshake the server trusts the rest of the connection, so the
+      // client-side SASL state is no longer needed.
+      sasl.dispose();
+    } catch (RuntimeException e) {
+      logger.error("Error while disposing SASL client", e);
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/23643403/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java
----------------------------------------------------------------------
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java
new file mode 100644
index 0000000..5b77e18
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import com.google.common.base.Charsets;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.protocol.Encodable;
+
+/**
+ * Encodes a Sasl-related message which is attempting to authenticate using some credentials tagged
+ * with the given appId. This appId allows a single SaslRpcHandler to multiplex different
+ * applications which may be using different sets of credentials.
+ */
+class SaslMessage implements Encodable {
+
+  /** Serialization tag used to catch incorrect payloads. */
+  private static final byte TAG_BYTE = (byte) 0xEA;
+
+  public final String appId;
+  public final byte[] payload;
+
+  public SaslMessage(String appId, byte[] payload) {
+    this.appId = appId;
+    this.payload = payload;
+  }
+
+  @Override
+  public int encodedLength() {
+    // tag + appIdLength + appId + payloadLength + payload
+    return 1 + 4 + appId.getBytes(Charsets.UTF_8).length + 4 + payload.length;
+  }
+
+  /**
+   * Writes this message as: tag byte, int length + UTF-8 appId bytes, int length + payload.
+   */
+  @Override
+  public void encode(ByteBuf buf) {
+    buf.writeByte(TAG_BYTE);
+    byte[] idBytes = appId.getBytes(Charsets.UTF_8);
+    buf.writeInt(idBytes.length);
+    buf.writeBytes(idBytes);
+    buf.writeInt(payload.length);
+    buf.writeBytes(payload);
+  }
+
+  /**
+   * Decodes a message in the layout produced by {@link #encode(ByteBuf)}.
+   *
+   * SaslRpcHandler calls this on bytes from peers that have not authenticated yet, so each
+   * length prefix is checked against the bytes actually remaining in the buffer before any
+   * array is allocated. Without the check, a forged prefix could trigger an arbitrarily large
+   * allocation, or a NegativeArraySizeException.
+   *
+   * @throws IllegalStateException if the tag byte is wrong, or a length prefix is negative or
+   *     larger than the number of readable bytes left.
+   */
+  public static SaslMessage decode(ByteBuf buf) {
+    if (buf.readByte() != TAG_BYTE) {
+      throw new IllegalStateException("Expected SaslMessage, received something else");
+    }
+
+    byte[] idBytes = readLengthPrefixed(buf, "appId");
+    byte[] payload = readLengthPrefixed(buf, "payload");
+
+    return new SaslMessage(new String(idBytes, Charsets.UTF_8), payload);
+  }
+
+  /** Reads an int length followed by that many bytes, validating the length first. */
+  private static byte[] readLengthPrefixed(ByteBuf buf, String field) {
+    int length = buf.readInt();
+    int remaining = buf.readableBytes();
+    if (length < 0 || length > remaining) {
+      throw new IllegalStateException(String.format(
+        "Invalid SaslMessage %s length %d (%d bytes remaining)", field, length, remaining));
+    }
+    byte[] bytes = new byte[length];
+    buf.readBytes(bytes);
+    return bytes;
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/23643403/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
----------------------------------------------------------------------
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
new file mode 100644
index 0000000..3777a18
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import java.util.concurrent.ConcurrentMap;
+
+import com.google.common.base.Charsets;
+import com.google.common.collect.Maps;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.protocol.Encodable;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+
+/**
+ * RPC Handler which performs SASL authentication before delegating to a child RPC handler.
+ * The delegate will only receive messages if the given connection has been successfully
+ * authenticated. A connection may be authenticated at most once.
+ *
+ * Note that the authentication process consists of multiple challenge-response pairs, each of
+ * which are individual RPCs.
+ */
+public class SaslRpcHandler extends RpcHandler {
+  private final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class);
+
+  /** RpcHandler we will delegate to for authenticated connections. */
+  private final RpcHandler delegate;
+
+  /** Class which provides secret keys which are shared by server and client on a per-app basis. */
+  private final SecretKeyHolder secretKeyHolder;
+
+  /** Maps each channel to its SASL authentication state. */
+  private final ConcurrentMap<TransportClient, SparkSaslServer> channelAuthenticationMap;
+
+  public SaslRpcHandler(RpcHandler delegate, SecretKeyHolder secretKeyHolder) {
+    this.delegate = delegate;
+    this.secretKeyHolder = secretKeyHolder;
+    this.channelAuthenticationMap = Maps.newConcurrentMap();
+  }
+
+  /**
+   * Handles one RPC. On a connection whose handshake has completed, the message is passed to
+   * the delegate unchanged. Otherwise it is decoded as a {@link SaslMessage} and fed into that
+   * connection's SASL server, and the server's reply token is returned through the callback.
+   */
+  @Override
+  public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+    SparkSaslServer saslServer = channelAuthenticationMap.get(client);
+    if (saslServer != null && saslServer.isComplete()) {
+      // Authentication complete, delegate to base handler.
+      delegate.receive(client, message, callback);
+      return;
+    }
+
+    SaslMessage saslMessage = SaslMessage.decode(Unpooled.wrappedBuffer(message));
+
+    if (saslServer == null) {
+      // First message in the handshake: its appId selects the credentials used for the rest of
+      // this connection's handshake.
+      saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder);
+      channelAuthenticationMap.put(client, saslServer);
+    }
+
+    byte[] response = saslServer.response(saslMessage.payload);
+    if (saslServer.isComplete()) {
+      logger.debug("SASL authentication successful for channel {}", client);
+    }
+    callback.onSuccess(response);
+  }
+
+  /**
+   * Returns the delegate's StreamManager.
+   *
+   * NOTE(review): stream/chunk fetches served through this StreamManager do not pass through
+   * {@link #receive}, so they are not gated by the SASL check there. Confirm that stream ids are
+   * not guessable by unauthenticated peers, or add an authorization check for fetches.
+   */
+  @Override
+  public StreamManager getStreamManager() {
+    return delegate.getStreamManager();
+  }
+
+  /**
+   * Drops and disposes of the SASL state of a terminated connection, then notifies the
+   * delegate, which may keep per-connection state of its own. The delegate is notified even if
+   * disposing of the SASL server throws.
+   */
+  @Override
+  public void connectionTerminated(TransportClient client) {
+    try {
+      SparkSaslServer saslServer = channelAuthenticationMap.remove(client);
+      if (saslServer != null) {
+        saslServer.dispose();
+      }
+    } finally {
+      delegate.connectionTerminated(client);
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/23643403/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java
----------------------------------------------------------------------
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java
new file mode 100644
index 0000000..81d5766
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+/**
+ * Interface for getting a secret key associated with some application.
+ */
+public interface SecretKeyHolder {
+  /**
+   * Looks up the SASL user name that the given application authenticates as.
+   *
+   * @param appId identifier of the application whose credentials are requested
+   * @return the SASL user associated with {@code appId}
+   * @throws IllegalArgumentException if no SASL user is associated with {@code appId}
+   */
+  String getSaslUser(String appId);
+
+  /**
+   * Looks up the shared SASL secret used by the given application.
+   *
+   * @param appId identifier of the application whose credentials are requested
+   * @return the secret key associated with {@code appId}
+   * @throws IllegalArgumentException if no SASL secret key is associated with {@code appId}
+   */
+  String getSecretKey(String appId);
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/23643403/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java
----------------------------------------------------------------------
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java
new file mode 100644
index 0000000..72ba737
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import javax.security.auth.callback.Callback;
+import javax.security.auth.callback.CallbackHandler;
+import javax.security.auth.callback.NameCallback;
+import javax.security.auth.callback.PasswordCallback;
+import javax.security.auth.callback.UnsupportedCallbackException;
+import javax.security.sasl.RealmCallback;
+import javax.security.sasl.RealmChoiceCallback;
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslClient;
+import javax.security.sasl.SaslException;
+import java.io.IOException;
+
+import com.google.common.base.Throwables;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import static org.apache.spark.network.sasl.SparkSaslServer.*;
+
+/**
+ * A SASL Client for Spark which simply keeps track of the state of a single SASL session, from the
+ * initial state to the "authenticated" state. This client initializes the protocol via a
+ * firstToken, which is then followed by a set of challenges and responses.
+ */
+public class SparkSaslClient {
+  private final Logger logger = LoggerFactory.getLogger(SparkSaslClient.class);
+
+  private final String secretKeyId;
+  private final SecretKeyHolder secretKeyHolder;
+  private SaslClient saslClient;
+
+  /**
+   * Creates a SASL client for the DIGEST mechanism. Credentials are looked up in
+   * {@code secretKeyHolder} under {@code secretKeyId}.
+   *
+   * @throws IllegalStateException if no installed SASL provider supports the mechanism. A null
+   *     client would make {@link #isComplete()} permanently false, so a handshake loop that
+   *     waits for completion could never terminate.
+   */
+  public SparkSaslClient(String secretKeyId, SecretKeyHolder secretKeyHolder) {
+    this.secretKeyId = secretKeyId;
+    this.secretKeyHolder = secretKeyHolder;
+    try {
+      this.saslClient = Sasl.createSaslClient(new String[] { DIGEST }, null, null, DEFAULT_REALM,
+        SASL_PROPS, new ClientCallbackHandler());
+    } catch (SaslException e) {
+      throw Throwables.propagate(e);
+    }
+    if (saslClient == null) {
+      throw new IllegalStateException("No SASL client available for mechanism " + DIGEST);
+    }
+  }
+
+  /** Used to initiate SASL handshake with server. */
+  public synchronized byte[] firstToken() {
+    if (saslClient != null && saslClient.hasInitialResponse()) {
+      try {
+        return saslClient.evaluateChallenge(new byte[0]);
+      } catch (SaslException e) {
+        throw Throwables.propagate(e);
+      }
+    } else {
+      return new byte[0];
+    }
+  }
+
+  /** Determines whether the authentication exchange has completed. */
+  public synchronized boolean isComplete() {
+    return saslClient != null && saslClient.isComplete();
+  }
+
+  /**
+   * Respond to server's SASL token.
+   * @param token contains server's SASL token
+   * @return client's response SASL token (empty once this client has been disposed)
+   */
+  public synchronized byte[] response(byte[] token) {
+    try {
+      return saslClient != null ? saslClient.evaluateChallenge(token) : new byte[0];
+    } catch (SaslException e) {
+      throw Throwables.propagate(e);
+    }
+  }
+
+  /**
+   * Disposes of any system resources or security-sensitive information the
+   * SaslClient might be using. Safe to call more than once.
+   */
+  public synchronized void dispose() {
+    if (saslClient != null) {
+      try {
+        saslClient.dispose();
+      } catch (SaslException e) {
+        // Best-effort cleanup: record the failure instead of silently dropping it, but do not
+        // propagate it, since the session is being torn down anyway.
+        logger.debug("Error while disposing SASL client", e);
+      } finally {
+        saslClient = null;
+      }
+    }
+  }
+
+  /**
+   * Implementation of javax.security.auth.callback.CallbackHandler
+   * that works with shared secrets.
+   */
+  private class ClientCallbackHandler implements CallbackHandler {
+    @Override
+    public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
+
+      for (Callback callback : callbacks) {
+        if (callback instanceof NameCallback) {
+          logger.trace("SASL client callback: setting username");
+          NameCallback nc = (NameCallback) callback;
+          nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId)));
+        } else if (callback instanceof PasswordCallback) {
+          logger.trace("SASL client callback: setting password");
+          PasswordCallback pc = (PasswordCallback) callback;
+          pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId)));
+        } else if (callback instanceof RealmCallback) {
+          logger.trace("SASL client callback: setting realm");
+          RealmCallback rc = (RealmCallback) callback;
+          rc.setText(rc.getDefaultText());
+        } else if (callback instanceof RealmChoiceCallback) {
+          // Deliberately left unanswered. Presumably the realm passed at creation
+          // (DEFAULT_REALM) is sufficient, so no choice needs to be made here; TODO confirm.
+        } else {
+          throw new UnsupportedCallbackException(callback, "Unrecognized SASL DIGEST-MD5 Callback");
+        }
+      }
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message