spark-commits mailing list archives

From r...@apache.org
Subject [2/2] spark git commit: [SPARK-6229] Add SASL encryption to network library.
Date Sat, 02 May 2015 02:02:25 GMT
[SPARK-6229] Add SASL encryption to network library.

There are two main parts of this change:

- Extending the bootstrap mechanism in the network library to add a server-side
  bootstrap (which works a little differently than the client-side bootstrap), and
  to allow the bootstraps to modify the underlying channel.

- Using SASL to encrypt data going through the RPC channel.

The second item requires some suboptimal code to work around the fact that
netty's outbound path is not thread-safe, and ordering is very important
when encryption is in the picture.
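To see why ordering matters, here is a toy sketch (not Spark or netty code) that assumes a stateful wrap which stamps every frame with a sequence number, similar in spirit to how stateful SASL mechanisms track message order. If frames are wrapped eagerly on producer threads and then reach the socket in a different order, the receiver rejects them. Deferring the wrap to the single writer keeps wrap order and wire order identical, which is the idea behind wrapping each outgoing message lazily. `Wrapper`, `Unwrapper`, `eagerReorderFails` and `sendLazily` are hypothetical names.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentLinkedQueue;

final class LazyWrapOrdering {
  /** Hypothetical stateful wrap: each frame carries the next sequence number. */
  static final class Wrapper {
    private long next = 0;

    String wrap(String plain) {
      return (next++) + ":" + plain;
    }
  }

  /** Hypothetical receiver that only accepts frames in sequence order. */
  static final class Unwrapper {
    private long expected = 0;

    String unwrap(String frame) {
      int sep = frame.indexOf(':');
      long seq = Long.parseLong(frame.substring(0, sep));
      if (seq != expected) {
        throw new IllegalStateException("got frame " + seq + ", expected " + expected);
      }
      expected++;
      return frame.substring(sep + 1);
    }
  }

  /** Eager wrapping gone wrong: a second thread's frame reaches the wire first. */
  static boolean eagerReorderFails() {
    Wrapper wrapper = new Wrapper();
    String first = wrapper.wrap("a");  // seq 0, wrapped by thread 1
    String second = wrapper.wrap("b"); // seq 1, wrapped by thread 2
    Unwrapper receiver = new Unwrapper();
    try {
      receiver.unwrap(second); // thread 2 won the race to the socket
      receiver.unwrap(first);
      return false;
    } catch (IllegalStateException e) {
      return true;
    }
  }

  /** Lazy wrapping: producers enqueue plaintext; the single writer wraps as it drains. */
  static int sendLazily(int producers, int messagesEach) {
    ConcurrentLinkedQueue<String> outbound = new ConcurrentLinkedQueue<>();
    List<Thread> threads = new ArrayList<>();
    for (int p = 0; p < producers; p++) {
      final int id = p;
      Thread t = new Thread(() -> {
        for (int i = 0; i < messagesEach; i++) {
          outbound.add("producer" + id + "-msg" + i);
        }
      });
      threads.add(t);
      t.start();
    }
    try {
      for (Thread t : threads) {
        t.join();
      }
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      throw new IllegalStateException(e);
    }
    Wrapper wrapper = new Wrapper();
    Unwrapper receiver = new Unwrapper();
    int delivered = 0;
    for (String plain = outbound.poll(); plain != null; plain = outbound.poll()) {
      // Wrap order equals wire order because only this thread wraps and writes.
      receiver.unwrap(wrapper.wrap(plain));
      delivered++;
    }
    return delivered;
  }
}
```

Producers here may interleave arbitrarily; the receiver still accepts every frame because sequence numbers are assigned at drain time rather than at enqueue time.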

A lot of the changes outside the network/common library are just to adjust to the
changed API for initializing the RPC server.

Author: Marcelo Vanzin <vanzin@cloudera.com>

Closes #5377 from vanzin/SPARK-6229 and squashes the following commits:

ff01966 [Marcelo Vanzin] Use fancy new size config style.
be53f32 [Marcelo Vanzin] Merge branch 'master' into SPARK-6229
47d4aff [Marcelo Vanzin] Merge branch 'master' into SPARK-6229
7a2a805 [Marcelo Vanzin] Clean up some unneeded changes.
2f92237 [Marcelo Vanzin] Add comment.
67bb0c6 [Marcelo Vanzin] Revert "Avoid exposing ByteArrayWritableChannel outside of test code."
065f684 [Marcelo Vanzin] Add test to verify chunking.
3d1695d [Marcelo Vanzin] Minor cleanups.
73cff0e [Marcelo Vanzin] Skip bytes in decode path too.
318ad23 [Marcelo Vanzin] Avoid exposing ByteArrayWritableChannel outside of test code.
346f829 [Marcelo Vanzin] Avoid trip through channel selector by not reporting 0 bytes written.
a4a5938 [Marcelo Vanzin] Review feedback.
4797519 [Marcelo Vanzin] Remove unused import.
9908ada [Marcelo Vanzin] Fix test, SASL backend disposal.
7fe1489 [Marcelo Vanzin] Add a test that makes sure encryption is actually enabled.
adb6f9d [Marcelo Vanzin] Review feedback.
cf2a605 [Marcelo Vanzin] Clean up some code.
8584323 [Marcelo Vanzin] Fix a comment.
e98bc55 [Marcelo Vanzin] Add option to only allow encrypted connections to the server.
dad42fc [Marcelo Vanzin] Make encryption thread-safe, less memory-intensive.
b00999a [Marcelo Vanzin] Consolidate ByteArrayWritableChannel, fix SASL code to match master changes.
b923cae [Marcelo Vanzin] Make SASL encryption handler thread-safe, handle FileRegion messages.
39539a7 [Marcelo Vanzin] Add config option to enable SASL encryption.
351a86f [Marcelo Vanzin] Add SASL encryption to network library.
fbe6ccb [Marcelo Vanzin] Add TransportServerBootstrap, make SASL code use it.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/38d4e9e4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/38d4e9e4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/38d4e9e4

Branch: refs/heads/master
Commit: 38d4e9e446b425ca6a8fe8d8080f387b08683842
Parents: 8f50a07
Author: Marcelo Vanzin <vanzin@cloudera.com>
Authored: Fri May 1 19:01:46 2015 -0700
Committer: Reynold Xin <rxin@databricks.com>
Committed: Fri May 1 19:01:46 2015 -0700

----------------------------------------------------------------------
 .../org/apache/spark/SecurityManager.scala      |  17 +-
 .../spark/deploy/ExternalShuffleService.scala   |  17 +-
 .../netty/NettyBlockTransferService.scala       |  22 +-
 .../spark/network/nio/ConnectionManager.scala   |   4 +-
 .../org/apache/spark/storage/BlockManager.scala |   3 +-
 .../apache/spark/network/TransportContext.java  |  26 +-
 .../client/TransportClientBootstrap.java        |   4 +-
 .../network/client/TransportClientFactory.java  |   5 +-
 .../spark/network/sasl/SaslClientBootstrap.java |  41 ++-
 .../spark/network/sasl/SaslEncryption.java      | 291 +++++++++++++++
 .../network/sasl/SaslEncryptionBackend.java     |  33 ++
 .../spark/network/sasl/SaslRpcHandler.java      |  56 ++-
 .../spark/network/sasl/SaslServerBootstrap.java |  49 +++
 .../spark/network/sasl/SparkSaslClient.java     |  33 +-
 .../spark/network/sasl/SparkSaslServer.java     |  49 ++-
 .../spark/network/server/TransportServer.java   |  19 +-
 .../server/TransportServerBootstrap.java        |  36 ++
 .../network/util/ByteArrayWritableChannel.java  |  69 ++++
 .../spark/network/util/TransportConf.java       |  18 +
 .../spark/network/ByteArrayWritableChannel.java |  55 ---
 .../org/apache/spark/network/ProtocolSuite.java |   1 +
 .../protocol/MessageWithHeaderSuite.java        |   2 +-
 .../spark/network/sasl/SparkSaslSuite.java      | 358 ++++++++++++++++++-
 .../network/shuffle/ExternalShuffleClient.java  |  11 +-
 .../network/sasl/SaslIntegrationSuite.java      |   9 +-
 .../ExternalShuffleIntegrationSuite.java        |   4 +-
 .../shuffle/ExternalShuffleSecuritySuite.java   |  27 +-
 .../spark/network/yarn/YarnShuffleService.java  |  15 +-
 28 files changed, 1119 insertions(+), 155 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/38d4e9e4/core/src/main/scala/org/apache/spark/SecurityManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala
index 3653f72..8aed1e2 100644
--- a/core/src/main/scala/org/apache/spark/SecurityManager.scala
+++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -150,8 +150,13 @@ import org.apache.spark.util.Utils
  *  authorization. If not filter is in place the user is generally null and no authorization
  *  can take place.
  *
- *  Connection encryption (SSL) configuration is organized hierarchically. The user can configure
- *  the default SSL settings which will be used for all the supported communication protocols unless
+ *  When authentication is being used, encryption can also be enabled by setting the option
+ *  spark.authenticate.enableSaslEncryption to true. This is only supported by communication
+ *  channels that use the network-common library, and can be used as an alternative to SSL in those
+ *  cases.
+ *
+ *  SSL can be used for encryption for certain communication channels. The user can configure the
+ *  default SSL settings which will be used for all the supported communication protocols unless
  *  they are overwritten by protocol specific settings. This way the user can easily provide the
  *  common settings for all the protocols without disabling the ability to configure each one
  *  individually.
@@ -413,6 +418,14 @@ private[spark] class SecurityManager(sparkConf: SparkConf)
   def isAuthenticationEnabled(): Boolean = authOn
 
   /**
+   * Checks whether SASL encryption should be enabled.
+   * @return Whether to enable SASL encryption when connecting to services that support it.
+   */
+  def isSaslEncryptionEnabled(): Boolean = {
+    sparkConf.getBoolean("spark.authenticate.enableSaslEncryption", false)
+  }
+
+  /**
    * Gets the user used for authenticating HTTP connections.
    * For now use a single hardcoded user.
    * @return the HTTP user as a String

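The new option defaults to `false`, and the SecurityManager documentation ties it to authentication being enabled. A minimal sketch of that gating, assuming a plain `java.util.Properties` in place of `SparkConf` (`spark.authenticate` is Spark's existing authentication switch; the class and the `willEncrypt` helper are hypothetical):

```java
import java.util.Properties;

final class SaslEncryptionConfig {
  static final String AUTH_KEY = "spark.authenticate";
  static final String SASL_ENCRYPTION_KEY = "spark.authenticate.enableSaslEncryption";

  static boolean isAuthenticationEnabled(Properties conf) {
    return Boolean.parseBoolean(conf.getProperty(AUTH_KEY, "false"));
  }

  /** Same default as SecurityManager.isSaslEncryptionEnabled(): off unless set to true. */
  static boolean isSaslEncryptionEnabled(Properties conf) {
    // Unlike SparkConf.getBoolean, parseBoolean maps any value other than "true" to false.
    return Boolean.parseBoolean(conf.getProperty(SASL_ENCRYPTION_KEY, "false"));
  }

  /** Hypothetical helper: encryption is only negotiated on authenticated connections. */
  static boolean willEncrypt(Properties conf) {
    return isAuthenticationEnabled(conf) && isSaslEncryptionEnabled(conf);
  }

  public static void main(String[] args) {
    Properties conf = new Properties();
    conf.setProperty(SASL_ENCRYPTION_KEY, "true");
    System.out.println(willEncrypt(conf)); // false: authentication is still off
    conf.setProperty(AUTH_KEY, "true");
    System.out.println(willEncrypt(conf)); // true
  }
}
```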
http://git-wip-us.apache.org/repos/asf/spark/blob/38d4e9e4/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala
index cd16f99..09973a0 100644
--- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala
@@ -19,10 +19,12 @@ package org.apache.spark.deploy
 
 import java.util.concurrent.CountDownLatch
 
+import scala.collection.JavaConversions._
+
 import org.apache.spark.{Logging, SparkConf, SecurityManager}
 import org.apache.spark.network.TransportContext
 import org.apache.spark.network.netty.SparkTransportConf
-import org.apache.spark.network.sasl.SaslRpcHandler
+import org.apache.spark.network.sasl.SaslServerBootstrap
 import org.apache.spark.network.server.TransportServer
 import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler
 import org.apache.spark.util.Utils
@@ -44,10 +46,7 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana
 
   private val transportConf = SparkTransportConf.fromSparkConf(sparkConf, numUsableCores = 0)
   private val blockHandler = new ExternalShuffleBlockHandler(transportConf)
-  private val transportContext: TransportContext = {
-    val handler = if (useSasl) new SaslRpcHandler(blockHandler, securityManager) else blockHandler
-    new TransportContext(transportConf, handler)
-  }
+  private val transportContext: TransportContext = new TransportContext(transportConf, blockHandler)
 
   private var server: TransportServer = _
 
@@ -62,7 +61,13 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana
   def start() {
     require(server == null, "Shuffle server already started")
     logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl")
-    server = transportContext.createServer(port)
+    val bootstraps =
+      if (useSasl) {
+        Seq(new SaslServerBootstrap(transportConf, securityManager))
+      } else {
+        Nil
+      }
+    server = transportContext.createServer(port, bootstraps)
   }
 
   def stop() {

http://git-wip-us.apache.org/repos/asf/spark/blob/38d4e9e4/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
index 3f0950d..6181c0e 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -24,7 +24,7 @@ import org.apache.spark.{SecurityManager, SparkConf}
 import org.apache.spark.network._
 import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCallback, TransportClientFactory}
-import org.apache.spark.network.sasl.{SaslRpcHandler, SaslClientBootstrap}
+import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap}
 import org.apache.spark.network.server._
 import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher}
 import org.apache.spark.network.shuffle.protocol.UploadBlock
@@ -49,18 +49,18 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
   private[this] var appId: String = _
 
   override def init(blockDataManager: BlockDataManager): Unit = {
-    val (rpcHandler: RpcHandler, bootstrap: Option[TransportClientBootstrap]) = {
-      val nettyRpcHandler = new NettyBlockRpcServer(serializer, blockDataManager)
-      if (!authEnabled) {
-        (nettyRpcHandler, None)
-      } else {
-        (new SaslRpcHandler(nettyRpcHandler, securityManager),
-          Some(new SaslClientBootstrap(transportConf, conf.getAppId, securityManager)))
-      }
+    val rpcHandler = new NettyBlockRpcServer(serializer, blockDataManager)
+    var serverBootstrap: Option[TransportServerBootstrap] = None
+    var clientBootstrap: Option[TransportClientBootstrap] = None
+    if (authEnabled) {
+      serverBootstrap = Some(new SaslServerBootstrap(transportConf, securityManager))
+      clientBootstrap = Some(new SaslClientBootstrap(transportConf, conf.getAppId, securityManager,
+        securityManager.isSaslEncryptionEnabled()))
     }
     transportContext = new TransportContext(transportConf, rpcHandler)
-    clientFactory = transportContext.createClientFactory(bootstrap.toList)
-    server = transportContext.createServer(conf.getInt("spark.blockManager.port", 0))
+    clientFactory = transportContext.createClientFactory(clientBootstrap.toList)
+    server = transportContext.createServer(conf.getInt("spark.blockManager.port", 0),
+      serverBootstrap.toList)
     appId = conf.getAppId
     logInfo("Server created on " + server.getPort)
   }

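The new `init` builds both bootstrap lists from a single `authEnabled` check: a SASL server bootstrap is added whenever authentication is on, while only the client bootstrap receives the encryption flag. The sketch below restates that wiring with hypothetical stand-in types (`ServerBootstrap`, `ClientBootstrap`, `SaslServerStub`, `SaslClientStub`) and `java.util.Optional` in place of Scala's `Option`:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

final class BootstrapWiring {
  // Hypothetical stand-ins for TransportServerBootstrap and TransportClientBootstrap.
  interface ServerBootstrap {}

  interface ClientBootstrap {}

  static final class SaslServerStub implements ServerBootstrap {}

  static final class SaslClientStub implements ClientBootstrap {
    final boolean encrypt;

    SaslClientStub(boolean encrypt) {
      this.encrypt = encrypt;
    }
  }

  /** The two lists handed to createServer(...) and createClientFactory(...). */
  static final class Wiring {
    final List<ServerBootstrap> server;
    final List<ClientBootstrap> client;

    Wiring(List<ServerBootstrap> server, List<ClientBootstrap> client) {
      this.server = server;
      this.client = client;
    }
  }

  static Wiring wire(boolean authEnabled, boolean saslEncryptionEnabled) {
    Optional<ServerBootstrap> serverBootstrap = Optional.empty();
    Optional<ClientBootstrap> clientBootstrap = Optional.empty();
    if (authEnabled) {
      // In this diff the server bootstrap is built without an encryption flag.
      serverBootstrap = Optional.of(new SaslServerStub());
      clientBootstrap = Optional.of(new SaslClientStub(saslEncryptionEnabled));
    }
    return new Wiring(toList(serverBootstrap), toList(clientBootstrap));
  }

  /** Equivalent of Scala's Option.toList. */
  private static <T> List<T> toList(Optional<T> value) {
    List<T> out = new ArrayList<>();
    value.ifPresent(out::add);
    return out;
  }
}
```

With authentication off, both lists are empty, so `createServer` and `createClientFactory` behave exactly as they did before this change.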
http://git-wip-us.apache.org/repos/asf/spark/blob/38d4e9e4/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
index 16e9059..497871e 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
@@ -656,7 +656,7 @@ private[nio] class ConnectionManager(
         connection.synchronized {
           if (connection.sparkSaslServer == null) {
             logDebug("Creating sasl Server")
-            connection.sparkSaslServer = new SparkSaslServer(conf.getAppId, securityManager)
+            connection.sparkSaslServer = new SparkSaslServer(conf.getAppId, securityManager, false)
           }
         }
         replyToken = connection.sparkSaslServer.response(securityMsg.getToken)
@@ -800,7 +800,7 @@ private[nio] class ConnectionManager(
     if (!conn.isSaslComplete()) {
       conn.synchronized {
         if (conn.sparkSaslClient == null) {
-          conn.sparkSaslClient = new SparkSaslClient(conf.getAppId, securityManager)
+          conn.sparkSaslClient = new SparkSaslClient(conf.getAppId, securityManager, false)
           var firstResponse: Array[Byte] = null
           try {
             firstResponse = conn.sparkSaslClient.firstToken()

http://git-wip-us.apache.org/repos/asf/spark/blob/38d4e9e4/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 402ee1c..a46fecd 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -111,7 +111,8 @@ private[spark] class BlockManager(
   // standard BlockTransferService to directly connect to other Executors.
   private[spark] val shuffleClient = if (externalShuffleServiceEnabled) {
     val transConf = SparkTransportConf.fromSparkConf(conf, numUsableCores)
-    new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled())
+    new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled(),
+      securityManager.isSaslEncryptionEnabled())
   } else {
     blockTransferService
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/38d4e9e4/network/common/src/main/java/org/apache/spark/network/TransportContext.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
index 3fe69b1..b8d073f 100644
--- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java
+++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
@@ -36,6 +36,7 @@ import org.apache.spark.network.server.RpcHandler;
 import org.apache.spark.network.server.TransportChannelHandler;
 import org.apache.spark.network.server.TransportRequestHandler;
 import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.server.TransportServerBootstrap;
 import org.apache.spark.network.util.NettyUtils;
 import org.apache.spark.network.util.TransportConf;
 
@@ -82,13 +83,21 @@ public class TransportContext {
   }
 
   /** Create a server which will attempt to bind to a specific port. */
-  public TransportServer createServer(int port) {
-    return new TransportServer(this, port);
+  public TransportServer createServer(int port, List<TransportServerBootstrap> bootstraps) {
+    return new TransportServer(this, port, rpcHandler, bootstraps);
   }
 
   /** Creates a new server, binding to any available ephemeral port. */
+  public TransportServer createServer(List<TransportServerBootstrap> bootstraps) {
+    return createServer(0, bootstraps);
+  }
+
   public TransportServer createServer() {
-    return new TransportServer(this, 0);
+    return createServer(0, Lists.<TransportServerBootstrap>newArrayList());
+  }
+
+  public TransportChannelHandler initializePipeline(SocketChannel channel) {
+    return initializePipeline(channel, rpcHandler);
   }
 
   /**
@@ -96,13 +105,18 @@ public class TransportContext {
    * has a {@link org.apache.spark.network.server.TransportChannelHandler} to handle request or
    * response messages.
    *
+   * @param channel The channel to initialize.
+   * @param channelRpcHandler The RPC handler to use for the channel.
+   *
    * @return Returns the created TransportChannelHandler, which includes a TransportClient that can
    * be used to communicate on this channel. The TransportClient is directly associated with a
    * ChannelHandler to ensure all users of the same channel get the same TransportClient object.
    */
-  public TransportChannelHandler initializePipeline(SocketChannel channel) {
+  public TransportChannelHandler initializePipeline(
+      SocketChannel channel,
+      RpcHandler channelRpcHandler) {
     try {
-      TransportChannelHandler channelHandler = createChannelHandler(channel);
+      TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
       channel.pipeline()
         .addLast("encoder", encoder)
         .addLast("frameDecoder", NettyUtils.createFrameDecoder())
@@ -123,7 +137,7 @@ public class TransportContext {
    * ResponseMessages. The channel is expected to have been successfully created, though certain
    * properties (such as the remoteAddress()) may not be available yet.
    */
-  private TransportChannelHandler createChannelHandler(Channel channel) {
+  private TransportChannelHandler createChannelHandler(Channel channel, RpcHandler rpcHandler) {
     TransportResponseHandler responseHandler = new TransportResponseHandler(channel);
     TransportClient client = new TransportClient(channel, responseHandler);
     TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client,

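`initializePipeline` now accepts the `RpcHandler` to use for each channel instead of always using the context-wide one. That suggests (the `TransportServer` and `TransportServerBootstrap` changes are not included in this excerpt) that each server bootstrap can substitute the handler for an individual connection, which would let per-connection SASL state replace the old one-time `SaslRpcHandler` wrapping seen in the `ExternalShuffleService` diff. A minimal sketch of that folding, assuming hypothetical stand-in interfaces:

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

final class PerChannelHandlers {
  /** Hypothetical stand-in for RpcHandler. */
  interface RpcHandler {
    String receive(String request);
  }

  /** Hypothetical server bootstrap: may wrap or replace the handler for one accepted channel. */
  interface ServerBootstrap {
    RpcHandler doBootstrap(String channelId, RpcHandler handler);
  }

  /** Applies the bootstraps in order; the result would be passed to initializePipeline. */
  static RpcHandler handlerFor(String channelId, RpcHandler base, List<ServerBootstrap> bootstraps) {
    RpcHandler handler = base;
    for (ServerBootstrap bootstrap : bootstraps) {
      handler = bootstrap.doBootstrap(channelId, handler);
    }
    return handler;
  }

  public static void main(String[] args) {
    RpcHandler app = request -> "reply(" + request + ")";
    // Example bootstrap holding per-channel state by closing over the channel id.
    ServerBootstrap tagging = (channelId, next) -> request -> channelId + ":" + next.receive(request);

    System.out.println(handlerFor("ch1", app, Collections.emptyList()).receive("x")); // reply(x)
    System.out.println(handlerFor("ch2", app, Arrays.asList(tagging)).receive("x"));  // ch2:reply(x)
  }
}
```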
http://git-wip-us.apache.org/repos/asf/spark/blob/38d4e9e4/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java
index 65e8020..eaae2ee 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java
@@ -17,6 +17,8 @@
 
 package org.apache.spark.network.client;
 
+import io.netty.channel.Channel;
+
 /**
  * A bootstrap which is executed on a TransportClient before it is returned to the user.
  * This enables an initial exchange of information (e.g., SASL authentication tokens) on a once-per-
@@ -28,5 +30,5 @@ package org.apache.spark.network.client;
  */
 public interface TransportClientBootstrap {
   /** Performs the bootstrapping operation, throwing an exception on failure. */
-  public void doBootstrap(TransportClient client) throws RuntimeException;
+  void doBootstrap(TransportClient client, Channel channel) throws RuntimeException;
 }

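Passing the `Channel` to `doBootstrap` is what lets a bootstrap rewrite the pipeline after authentication. The SASL encryption code later in this commit installs its handlers with three `addFirst` calls (encryption, then decryption, then a frame decoder), so they end up ahead of the existing `encoder` and `frameDecoder` handlers. The sketch below models a netty pipeline as an ordered list of handler names (a hypothetical stand-in, not netty's `ChannelPipeline`) to show the resulting order. In netty, inbound data travels head to tail, so encrypted frames are delimited and decrypted before the regular frame decoder sees them; outbound data travels tail to head, so `saslEncryption` wraps whatever the regular encoder produced.

```java
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;

final class PipelineOrder {
  /** Hypothetical stand-in for a netty pipeline: handler names ordered from head to tail. */
  static final class Pipeline {
    private final LinkedList<String> names = new LinkedList<>();

    Pipeline addFirst(String name) {
      names.addFirst(name);
      return this;
    }

    Pipeline addLast(String name) {
      names.addLast(name);
      return this;
    }

    List<String> headToTail() {
      return new ArrayList<>(names);
    }
  }

  /** Hypothetical client bootstrap that receives the pipeline, as doBootstrap now can. */
  interface ClientBootstrap {
    void doBootstrap(Pipeline pipeline);
  }

  /** Same call order as SaslEncryption.addToChannel: each addFirst lands ahead of the last. */
  static final ClientBootstrap SASL_ENCRYPTION = pipeline -> pipeline
      .addFirst("saslEncryption")
      .addFirst("saslDecryption")
      .addFirst("saslFrameDecoder");

  public static void main(String[] args) {
    // Only the two handlers visible in TransportContext.initializePipeline are modeled.
    Pipeline pipeline = new Pipeline().addLast("encoder").addLast("frameDecoder");
    SASL_ENCRYPTION.doBootstrap(pipeline);
    // prints [saslFrameDecoder, saslDecryption, saslEncryption, encoder, frameDecoder]
    System.out.println(pipeline.headToTail());
  }
}
```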
http://git-wip-us.apache.org/repos/asf/spark/blob/38d4e9e4/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
index d26b9b4..4952ffb 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
@@ -172,12 +172,14 @@ public class TransportClientFactory implements Closeable {
       .option(ChannelOption.ALLOCATOR, pooledAllocator);
 
     final AtomicReference<TransportClient> clientRef = new AtomicReference<TransportClient>();
+    final AtomicReference<Channel> channelRef = new AtomicReference<Channel>();
 
     bootstrap.handler(new ChannelInitializer<SocketChannel>() {
       @Override
       public void initChannel(SocketChannel ch) {
         TransportChannelHandler clientHandler = context.initializePipeline(ch);
         clientRef.set(clientHandler.getClient());
+        channelRef.set(ch);
       }
     });
 
@@ -192,6 +194,7 @@ public class TransportClientFactory implements Closeable {
     }
 
     TransportClient client = clientRef.get();
+    Channel channel = channelRef.get();
     assert client != null : "Channel future completed successfully with null client";
 
     // Execute any client bootstraps synchronously before marking the Client as successful.
@@ -199,7 +202,7 @@ public class TransportClientFactory implements Closeable {
     logger.debug("Connection to {} successful, running bootstraps...", address);
     try {
       for (TransportClientBootstrap clientBootstrap : clientBootstraps) {
-        clientBootstrap.doBootstrap(client);
+        clientBootstrap.doBootstrap(client, channel);
       }
     } catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala
       long bootstrapTimeMs = (System.nanoTime() - preBootstrap) / 1000000;

http://git-wip-us.apache.org/repos/asf/spark/blob/38d4e9e4/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
index 33aa134..185ba2e 100644
--- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
@@ -17,8 +17,12 @@
 
 package org.apache.spark.network.sasl;
 
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslException;
+
 import io.netty.buffer.ByteBuf;
 import io.netty.buffer.Unpooled;
+import io.netty.channel.Channel;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -33,14 +37,24 @@ import org.apache.spark.network.util.TransportConf;
 public class SaslClientBootstrap implements TransportClientBootstrap {
   private final Logger logger = LoggerFactory.getLogger(SaslClientBootstrap.class);
 
+  private final boolean encrypt;
   private final TransportConf conf;
   private final String appId;
   private final SecretKeyHolder secretKeyHolder;
 
   public SaslClientBootstrap(TransportConf conf, String appId, SecretKeyHolder secretKeyHolder) {
+    this(conf, appId, secretKeyHolder, false);
+  }
+
+  public SaslClientBootstrap(
+      TransportConf conf,
+      String appId,
+      SecretKeyHolder secretKeyHolder,
+      boolean encrypt) {
     this.conf = conf;
     this.appId = appId;
     this.secretKeyHolder = secretKeyHolder;
+    this.encrypt = encrypt;
   }
 
   /**
@@ -49,8 +63,8 @@ public class SaslClientBootstrap implements TransportClientBootstrap {
    * due to mismatch.
    */
   @Override
-  public void doBootstrap(TransportClient client) {
-    SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder);
+  public void doBootstrap(TransportClient client, Channel channel) {
+    SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder, encrypt);
     try {
       byte[] payload = saslClient.firstToken();
 
@@ -62,13 +76,26 @@ public class SaslClientBootstrap implements TransportClientBootstrap {
         byte[] response = client.sendRpcSync(buf.array(), conf.saslRTTimeoutMs());
         payload = saslClient.response(response);
       }
+
+      if (encrypt) {
+        if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslClient.getNegotiatedProperty(Sasl.QOP))) {
+          throw new RuntimeException(
+            new SaslException("Encryption requested but negotiated non-encrypted connection."));
+        }
+        SaslEncryption.addToChannel(channel, saslClient, conf.maxSaslEncryptedBlockSize());
+        saslClient = null;
+        logger.debug("Channel {} configured for SASL encryption.", client);
+      }
     } finally {
-      try {
-        // Once authentication is complete, the server will trust all remaining communication.
-        saslClient.dispose();
-      } catch (RuntimeException e) {
-        logger.error("Error while disposing SASL client", e);
+      if (saslClient != null) {
+        try {
+          // Once authentication is complete, the server will trust all remaining communication.
+          saslClient.dispose();
+        } catch (RuntimeException e) {
+          logger.error("Error while disposing SASL client", e);
+        }
       }
     }
   }
+
 }

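The new check in `doBootstrap` refuses to install encryption unless the negotiated quality of protection is the confidentiality level (`SparkSaslServer.QOP_AUTH_CONF`, whose value is not shown in this excerpt; the standard SASL token for it is `auth-conf`). The sketch below reproduces that flow with the JDK's own `javax.security.sasl` API and the DIGEST-MD5 mechanism: it runs the challenge/response exchange in-process rather than over `sendRpcSync`, checks `Sasl.QOP`, and then performs one `wrap`/`unwrap` round trip, the same kind of call `DecryptionHandler` makes on its `SaslEncryptionBackend`. The credentials, the protocol name `spark`, the server name `default` and the class itself are assumptions for illustration, not Spark's code.

```java
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.sasl.AuthorizeCallback;
import javax.security.sasl.RealmCallback;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;

final class DigestQopSketch {
  // Hypothetical credentials; Spark derives its own from a SecretKeyHolder.
  private static final String USER = "sparkSaslUser";
  private static final char[] SECRET = "not-a-real-secret".toCharArray();

  /** One handler serves both sides; each side only asks for the callbacks it needs. */
  private static final CallbackHandler CALLBACKS = callbacks -> {
    for (Callback cb : callbacks) {
      if (cb instanceof NameCallback) {
        ((NameCallback) cb).setName(USER);
      } else if (cb instanceof PasswordCallback) {
        ((PasswordCallback) cb).setPassword(SECRET);
      } else if (cb instanceof RealmCallback) {
        RealmCallback rc = (RealmCallback) cb;
        rc.setText(rc.getDefaultText());
      } else if (cb instanceof AuthorizeCallback) {
        AuthorizeCallback ac = (AuthorizeCallback) cb;
        ac.setAuthorized(ac.getAuthenticationID().equals(ac.getAuthorizationID()));
        if (ac.isAuthorized()) {
          ac.setAuthorizedID(ac.getAuthorizationID());
        }
      } else {
        throw new UnsupportedCallbackException(cb);
      }
    }
  };

  private static Map<String, String> props(String qop) {
    Map<String, String> props = new HashMap<>();
    props.put(Sasl.QOP, qop);
    props.put(Sasl.SERVER_AUTH, "true");
    return props;
  }

  /** Runs the challenge/response exchange in-process instead of over an RPC channel. */
  private static void handshake(SaslClient client, SaslServer server) throws SaslException {
    byte[] token = client.hasInitialResponse() ? client.evaluateChallenge(new byte[0]) : new byte[0];
    while (!server.isComplete()) {
      byte[] challenge = server.evaluateResponse(token);
      if (challenge != null) {
        token = client.evaluateChallenge(challenge);
      }
    }
  }

  /** Mirrors the bootstrap's check: refuse to continue unless confidentiality was negotiated. */
  static void requireAuthConf(SaslClient client) {
    Object qop = client.getNegotiatedProperty(Sasl.QOP);
    if (!"auth-conf".equals(qop)) {
      throw new RuntimeException(new SaslException("Encryption requested but negotiated QOP was " + qop));
    }
  }

  /** Authenticates with the given QOP, then encrypts on the client and decrypts on the server. */
  static String roundTrip(String qop, String message) {
    try {
      SaslServer server = Sasl.createSaslServer("DIGEST-MD5", "spark", "default", props(qop), CALLBACKS);
      SaslClient client = Sasl.createSaslClient(
          new String[] {"DIGEST-MD5"}, null, "spark", "default", props(qop), CALLBACKS);
      try {
        handshake(client, server);
        requireAuthConf(client);
        byte[] plain = message.getBytes(StandardCharsets.UTF_8);
        byte[] wire = client.wrap(plain, 0, plain.length);
        return new String(server.unwrap(wire, 0, wire.length), StandardCharsets.UTF_8);
      } finally {
        client.dispose();
        server.dispose();
      }
    } catch (SaslException e) {
      throw new RuntimeException(e);
    }
  }
}
```

Calling `roundTrip("auth-conf", ...)` returns the original message, while `roundTrip("auth", ...)` authenticates successfully but then fails the QOP check, which is the situation the new `RuntimeException` in `doBootstrap` guards against.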
http://git-wip-us.apache.org/repos/asf/spark/blob/38d4e9e4/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java
new file mode 100644
index 0000000..127335e
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java
@@ -0,0 +1,291 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.WritableByteChannel;
+import java.util.List;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelOutboundHandlerAdapter;
+import io.netty.channel.ChannelPromise;
+import io.netty.channel.FileRegion;
+import io.netty.handler.codec.MessageToMessageDecoder;
+import io.netty.util.AbstractReferenceCounted;
+import io.netty.util.ReferenceCountUtil;
+
+import org.apache.spark.network.util.ByteArrayWritableChannel;
+import org.apache.spark.network.util.NettyUtils;
+
+/**
+ * Provides SASL-based encryption for transport channels. The single method exposed by this
+ * class installs the needed channel handlers on a connected channel.
+ */
+class SaslEncryption {
+
+  @VisibleForTesting
+  static final String ENCRYPTION_HANDLER_NAME = "saslEncryption";
+
+  /**
+   * Adds channel handlers that perform encryption / decryption of data using SASL.
+   *
+   * @param channel The channel.
+   * @param backend The SASL backend.
+   * @param maxOutboundBlockSize Max size in bytes of outgoing encrypted blocks, to control
+   *                             memory usage.
+   */
+  static void addToChannel(
+      Channel channel,
+      SaslEncryptionBackend backend,
+      int maxOutboundBlockSize) {
+    channel.pipeline()
+      .addFirst(ENCRYPTION_HANDLER_NAME, new EncryptionHandler(backend, maxOutboundBlockSize))
+      .addFirst("saslDecryption", new DecryptionHandler(backend))
+      .addFirst("saslFrameDecoder", NettyUtils.createFrameDecoder());
+  }
+
+  private static class EncryptionHandler extends ChannelOutboundHandlerAdapter {
+
+    private final int maxOutboundBlockSize;
+    private final SaslEncryptionBackend backend;
+
+    EncryptionHandler(SaslEncryptionBackend backend, int maxOutboundBlockSize) {
+      this.backend = backend;
+      this.maxOutboundBlockSize = maxOutboundBlockSize;
+    }
+
+    /**
+     * Wrap the incoming message in an implementation that will perform encryption lazily. This is
+     * needed to guarantee ordering of the outgoing encrypted packets - they need to be decrypted in
+     * the same order, and netty doesn't have an atomic ChannelHandlerContext.write() API, so it
+     * does not guarantee any ordering.
+     */
+    @Override
+    public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
+      throws Exception {
+
+      ctx.write(new EncryptedMessage(backend, msg, maxOutboundBlockSize), promise);
+    }
+
+    @Override
+    public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
+      try {
+        backend.dispose();
+      } finally {
+        super.handlerRemoved(ctx);
+      }
+    }
+
+  }
+
+  private static class DecryptionHandler extends MessageToMessageDecoder<ByteBuf> {
+
+    private final SaslEncryptionBackend backend;
+
+    DecryptionHandler(SaslEncryptionBackend backend) {
+      this.backend = backend;
+    }
+
+    @Override
+    protected void decode(ChannelHandlerContext ctx, ByteBuf msg, List<Object> out)
+      throws Exception {
+
+      byte[] data;
+      int offset;
+      int length = msg.readableBytes();
+      if (msg.hasArray()) {
+        data = msg.array();
+        offset = msg.arrayOffset();
+        msg.skipBytes(length);
+      } else {
+        data = new byte[length];
+        msg.readBytes(data);
+        offset = 0;
+      }
+
+      out.add(Unpooled.wrappedBuffer(backend.unwrap(data, offset, length)));
+    }
+
+  }
+
+  @VisibleForTesting
+  static class EncryptedMessage extends AbstractReferenceCounted implements FileRegion {
+
+    private final SaslEncryptionBackend backend;
+    private final boolean isByteBuf;
+    private final ByteBuf buf;
+    private final FileRegion region;
+
+    /**
+     * A channel used to buffer input data for encryption. The channel has an upper size bound
+     * so that if the input is larger than the allowed buffer, it will be broken into multiple
+     * chunks.
+     */
+    private final ByteArrayWritableChannel byteChannel;
+
+    private ByteBuf currentHeader;
+    private ByteBuffer currentChunk;
+    private long currentChunkSize;
+    private long currentReportedBytes;
+    private long unencryptedChunkSize;
+    private long transferred;
+
+    EncryptedMessage(SaslEncryptionBackend backend, Object msg, int maxOutboundBlockSize) {
+      Preconditions.checkArgument(msg instanceof ByteBuf || msg instanceof FileRegion,
+        "Unrecognized message type: %s", msg.getClass().getName());
+      this.backend = backend;
+      this.isByteBuf = msg instanceof ByteBuf;
+      this.buf = isByteBuf ? (ByteBuf) msg : null;
+      this.region = isByteBuf ? null : (FileRegion) msg;
+      this.byteChannel = new ByteArrayWritableChannel(maxOutboundBlockSize);
+    }
+
+    /**
+     * Returns the size of the original (unencrypted) message.
+     *
+     * This makes assumptions about how netty treats FileRegion instances, because there's no way
+     * to know beforehand what the size of the encrypted message will be. Namely, it assumes
+     * that netty will keep trying to transfer data from this message while
+     * {@code transfered() < count()}. So, technically, these two methods return wrong data,
+     * but netty doesn't know better.
+     */
+    @Override
+    public long count() {
+      return isByteBuf ? buf.readableBytes() : region.count();
+    }
+
+    @Override
+    public long position() {
+      return 0;
+    }
+
+    /**
+     * Returns an approximation of the amount of data transferred. See {@link #count()}.
+     */
+    @Override
+    public long transfered() {
+      return transferred;
+    }
+
+    /**
+     * Transfers data from the original message to the channel, encrypting it in the process.
+     *
+     * This method also breaks the original message into smaller chunks when needed, to keep
+     * memory usage under control: the whole message never has to be copied into memory at once,
+     * which avoids ballooning memory usage when transferring large messages such as shuffle
+     * blocks.
+     *
+     * The {@link #transfered()} counter also behaves unusually: apart from the one-byte
+     * progress reports explained at the end of this method, it doesn't advance until a whole
+     * chunk has been written, because the code can't use the number of bytes actually written
+     * to the channel as the transferred count (see {@link #count()}). Instead, once an encrypted
+     * chunk and its header are written, the still-unreported original block size is added.
+     */
+    @Override
+    public long transferTo(final WritableByteChannel target, final long position)
+      throws IOException {
+
+      Preconditions.checkArgument(position == transfered(), "Invalid position.");
+
+      long reportedWritten = 0L;
+      long actuallyWritten = 0L;
+      do {
+        if (currentChunk == null) {
+          nextChunk();
+        }
+
+        if (currentHeader.readableBytes() > 0) {
+          int bytesWritten = target.write(currentHeader.nioBuffer());
+          currentHeader.skipBytes(bytesWritten);
+          actuallyWritten += bytesWritten;
+          if (currentHeader.readableBytes() > 0) {
+            // Break out of loop if there are still header bytes left to write.
+            break;
+          }
+        }
+
+        actuallyWritten += target.write(currentChunk);
+        if (!currentChunk.hasRemaining()) {
+          // Only update the count of written bytes once a full chunk has been written.
+          // See method javadoc.
+          long chunkBytesRemaining = unencryptedChunkSize - currentReportedBytes;
+          reportedWritten += chunkBytesRemaining;
+          transferred += chunkBytesRemaining;
+          currentHeader.release();
+          currentHeader = null;
+          currentChunk = null;
+          currentChunkSize = 0;
+          currentReportedBytes = 0;
+        }
+      } while (currentChunk == null && transfered() + reportedWritten < count());
+
+      // Returning 0 triggers a backoff mechanism in netty which may harm performance. Instead,
+      // we return 1 for as long as we can (i.e. until the reported count would reach the size
+      // of the current chunk), at which point we resort to returning 0 so that the counts still
+      // match, at the cost of some performance. That situation should be rare, though.
+      if (reportedWritten != 0L) {
+        return reportedWritten;
+      }
+
+      if (actuallyWritten > 0 && currentReportedBytes < currentChunkSize - 1) {
+        transferred += 1L;
+        currentReportedBytes += 1L;
+        return 1L;
+      }
+
+      return 0L;
+    }
+
+    private void nextChunk() throws IOException {
+      byteChannel.reset();
+      if (isByteBuf) {
+        int copied = byteChannel.write(buf.nioBuffer());
+        buf.skipBytes(copied);
+      } else {
+        region.transferTo(byteChannel, region.transfered());
+      }
+
+      byte[] encrypted = backend.wrap(byteChannel.getData(), 0, byteChannel.length());
+      this.currentChunk = ByteBuffer.wrap(encrypted);
+      this.currentChunkSize = encrypted.length;
+      this.currentHeader = Unpooled.copyLong(8 + currentChunkSize);
+      this.unencryptedChunkSize = byteChannel.length();
+    }
+
+    @Override
+    protected void deallocate() {
+      if (currentHeader != null) {
+        currentHeader.release();
+      }
+      if (buf != null) {
+        buf.release();
+      }
+      if (region != null) {
+        region.release();
+      }
+    }
+
+  }
+
+}
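
The outbound framing done by EncryptedMessage can be sketched without netty: the message is cut into blocks of at most maxOutboundBlockSize bytes, each block is wrapped, and each wrapped block is preceded by an 8-byte big-endian length that counts the header itself (`Unpooled.copyLong(8 + currentChunkSize)`). The sketch's decoder strips that header before unwrapping; in the real pipeline this is presumably done by a frame decoder placed ahead of DecryptionHandler, which this excerpt does not show. `ToyBackend` and its XOR transform are hypothetical stand-ins for SaslEncryptionBackend and provide no security at all.

```java
import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.util.Arrays;

public class FramingSketch {

  /** Hypothetical stand-in for SaslEncryptionBackend's wrap/unwrap pair. */
  interface ToyBackend {
    byte[] wrap(byte[] data, int offset, int len);
    byte[] unwrap(byte[] data, int offset, int len);
  }

  /** XORs every byte and appends 4 zero bytes to mimic SASL overhead. NOT encryption. */
  static final ToyBackend XOR_BACKEND = new ToyBackend() {
    @Override
    public byte[] wrap(byte[] data, int offset, int len) {
      byte[] out = new byte[len + 4];  // trailing 4 bytes stay zero, standing in for a MAC
      for (int i = 0; i < len; i++) {
        out[i] = (byte) (data[offset + i] ^ 0x5A);
      }
      return out;
    }

    @Override
    public byte[] unwrap(byte[] data, int offset, int len) {
      byte[] out = new byte[len - 4];
      for (int i = 0; i < out.length; i++) {
        out[i] = (byte) (data[offset + i] ^ 0x5A);
      }
      return out;
    }
  };

  /** Cuts msg into blocks of at most maxBlock bytes; each frame is [8-byte length incl. header][wrapped block]. */
  static byte[] encode(ToyBackend backend, byte[] msg, int maxBlock) {
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    for (int pos = 0; pos < msg.length; pos += maxBlock) {
      int len = Math.min(maxBlock, msg.length - pos);
      byte[] wrapped = backend.wrap(msg, pos, len);
      byte[] header = ByteBuffer.allocate(8).putLong(8L + wrapped.length).array();
      out.write(header, 0, header.length);
      out.write(wrapped, 0, wrapped.length);
    }
    return out.toByteArray();
  }

  /** Reverses encode(): reads each length, unwraps the payload that follows, and concatenates. */
  static byte[] decode(ToyBackend backend, byte[] frames) {
    ByteBuffer in = ByteBuffer.wrap(frames);
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    while (in.hasRemaining()) {
      int payloadLen = (int) (in.getLong() - 8);
      byte[] block = backend.unwrap(frames, in.position(), payloadLen);
      out.write(block, 0, block.length);
      in.position(in.position() + payloadLen);
    }
    return out.toByteArray();
  }

  public static void main(String[] args) {
    byte[] msg = new byte[100];
    for (int i = 0; i < msg.length; i++) {
      msg[i] = (byte) i;
    }
    byte[] frames = encode(XOR_BACKEND, msg, 40);  // blocks of 40, 40 and 20 bytes
    System.out.println(frames.length);  // prints 136
    System.out.println(Arrays.equals(msg, decode(XOR_BACKEND, frames)));  // prints true
  }
}
```

Because each wrapped block is larger than its input, the total bytes on the wire (136 here) exceed the original 100, which is the mismatch EncryptedMessage hides by reporting the unencrypted size from `count()`.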

http://git-wip-us.apache.org/repos/asf/spark/blob/38d4e9e4/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java
new file mode 100644
index 0000000..89b78bc
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import javax.security.sasl.SaslException;
+
+interface SaslEncryptionBackend {
+
+  /** Disposes of resources used by the backend. */
+  void dispose();
+
+  /** Encrypt data. */
+  byte[] wrap(byte[] data, int offset, int len) throws SaslException;
+
+  /** Decrypt data. */
+  byte[] unwrap(byte[] data, int offset, int len) throws SaslException;
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/38d4e9e4/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
index 026cbd2..be6165c 100644
--- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
@@ -17,10 +17,10 @@
 
 package org.apache.spark.network.sasl;
 
-import java.util.concurrent.ConcurrentMap;
+import javax.security.sasl.Sasl;
 
-import com.google.common.collect.Maps;
 import io.netty.buffer.Unpooled;
+import io.netty.channel.Channel;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -28,6 +28,7 @@ import org.apache.spark.network.client.RpcResponseCallback;
 import org.apache.spark.network.client.TransportClient;
 import org.apache.spark.network.server.RpcHandler;
 import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.util.TransportConf;
 
 /**
  * RPC Handler which performs SASL authentication before delegating to a child RPC handler.
@@ -37,8 +38,14 @@ import org.apache.spark.network.server.StreamManager;
  * Note that the authentication process consists of multiple challenge-response pairs, each of
  * which are individual RPCs.
  */
-public class SaslRpcHandler extends RpcHandler {
-  private final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class);
+class SaslRpcHandler extends RpcHandler {
+  private static final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class);
+
+  /** Transport configuration. */
+  private final TransportConf conf;
+
+  /** The client channel. */
+  private final Channel channel;
 
   /** RpcHandler we will delegate to for authenticated connections. */
   private final RpcHandler delegate;
@@ -46,19 +53,25 @@ public class SaslRpcHandler extends RpcHandler {
   /** Class which provides secret keys which are shared by server and client on a per-app basis. */
   private final SecretKeyHolder secretKeyHolder;
 
-  /** Maps each channel to its SASL authentication state. */
-  private final ConcurrentMap<TransportClient, SparkSaslServer> channelAuthenticationMap;
+  private SparkSaslServer saslServer;
+  private boolean isComplete;
 
-  public SaslRpcHandler(RpcHandler delegate, SecretKeyHolder secretKeyHolder) {
+  SaslRpcHandler(
+      TransportConf conf,
+      Channel channel,
+      RpcHandler delegate,
+      SecretKeyHolder secretKeyHolder) {
+    this.conf = conf;
+    this.channel = channel;
     this.delegate = delegate;
     this.secretKeyHolder = secretKeyHolder;
-    this.channelAuthenticationMap = Maps.newConcurrentMap();
+    this.saslServer = null;
+    this.isComplete = false;
   }
 
   @Override
   public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
-    SparkSaslServer saslServer = channelAuthenticationMap.get(client);
-    if (saslServer != null && saslServer.isComplete()) {
+    if (isComplete) {
       // Authentication complete, delegate to base handler.
       delegate.receive(client, message, callback);
       return;
@@ -68,15 +81,30 @@ public class SaslRpcHandler extends RpcHandler {
 
     if (saslServer == null) {
       // First message in the handshake, setup the necessary state.
-      saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder);
-      channelAuthenticationMap.put(client, saslServer);
+      saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder,
+        conf.saslServerAlwaysEncrypt());
     }
 
     byte[] response = saslServer.response(saslMessage.payload);
+    callback.onSuccess(response);
+
+    // Set up encryption after the SASL response is sent, otherwise the client can't parse the
+    // response. It's ok to change the channel pipeline here since we are processing an incoming
+    // message, so the pipeline is busy and no new incoming messages will be fed to it before this
+    // method returns. This assumes that the code ensures, through other means, that no outbound
+    // messages are being written to the channel while negotiation is still going on.
     if (saslServer.isComplete()) {
       logger.debug("SASL authentication successful for channel {}", client);
+      isComplete = true;
+      if (SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP))) {
+        logger.debug("Enabling encryption for channel {}", client);
+        SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize());
+        saslServer = null;
+      } else {
+        saslServer.dispose();
+        saslServer = null;
+      }
     }
-    callback.onSuccess(response);
   }
 
   @Override
@@ -86,9 +114,9 @@ public class SaslRpcHandler extends RpcHandler {
 
   @Override
   public void connectionTerminated(TransportClient client) {
-    SparkSaslServer saslServer = channelAuthenticationMap.remove(client);
     if (saslServer != null) {
       saslServer.dispose();
     }
   }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/38d4e9e4/network/common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java
new file mode 100644
index 0000000..f2f9838
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import io.netty.channel.Channel;
+
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.TransportServerBootstrap;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * A bootstrap which is executed on a TransportServer's client channel once a client connects
+ * to the server. It wraps the application's RPC handler in a SaslRpcHandler, so that SASL
+ * authentication (and encryption, if negotiated) is set up before RPCs are delegated.
+ */
+public class SaslServerBootstrap implements TransportServerBootstrap {
+
+  private final TransportConf conf;
+  private final SecretKeyHolder secretKeyHolder;
+
+  public SaslServerBootstrap(TransportConf conf, SecretKeyHolder secretKeyHolder) {
+    this.conf = conf;
+    this.secretKeyHolder = secretKeyHolder;
+  }
+
+  /**
+   * Wrap the given application handler in a SaslRpcHandler that will handle the initial SASL
+   * negotiation.
+   */
+  public RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler) {
+    return new SaslRpcHandler(conf, channel, rpcHandler, secretKeyHolder);
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/38d4e9e4/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java b/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java
index 9abad1f..94685e9 100644
--- a/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java
@@ -17,6 +17,8 @@
 
 package org.apache.spark.network.sasl;
 
+import java.io.IOException;
+import java.util.Map;
 import javax.security.auth.callback.Callback;
 import javax.security.auth.callback.CallbackHandler;
 import javax.security.auth.callback.NameCallback;
@@ -27,9 +29,9 @@ import javax.security.sasl.RealmChoiceCallback;
 import javax.security.sasl.Sasl;
 import javax.security.sasl.SaslClient;
 import javax.security.sasl.SaslException;
-import java.io.IOException;
 
 import com.google.common.base.Throwables;
+import com.google.common.collect.ImmutableMap;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -40,19 +42,25 @@ import static org.apache.spark.network.sasl.SparkSaslServer.*;
  * initial state to the "authenticated" state. This client initializes the protocol via a
  * firstToken, which is then followed by a set of challenges and responses.
  */
-public class SparkSaslClient {
+public class SparkSaslClient implements SaslEncryptionBackend {
   private final Logger logger = LoggerFactory.getLogger(SparkSaslClient.class);
 
   private final String secretKeyId;
   private final SecretKeyHolder secretKeyHolder;
+  private final String expectedQop;
   private SaslClient saslClient;
 
-  public SparkSaslClient(String secretKeyId, SecretKeyHolder secretKeyHolder) {
+  public SparkSaslClient(String secretKeyId, SecretKeyHolder secretKeyHolder, boolean encrypt) {
     this.secretKeyId = secretKeyId;
     this.secretKeyHolder = secretKeyHolder;
+    this.expectedQop = encrypt ? QOP_AUTH_CONF : QOP_AUTH;
+
+    Map<String, String> saslProps = ImmutableMap.<String, String>builder()
+      .put(Sasl.QOP, expectedQop)
+      .build();
     try {
       this.saslClient = Sasl.createSaslClient(new String[] { DIGEST }, null, null, DEFAULT_REALM,
-        SASL_PROPS, new ClientCallbackHandler());
+        saslProps, new ClientCallbackHandler());
     } catch (SaslException e) {
       throw Throwables.propagate(e);
     }
@@ -76,6 +84,11 @@ public class SparkSaslClient {
     return saslClient != null && saslClient.isComplete();
   }
 
+  /** Returns the value of a negotiated property. */
+  public Object getNegotiatedProperty(String name) {
+    return saslClient.getNegotiatedProperty(name);
+  }
+
   /**
    * Respond to server's SASL token.
    * @param token contains server's SASL token
@@ -93,6 +106,7 @@ public class SparkSaslClient {
    * Disposes of any system resources or security-sensitive information the
    * SaslClient might be using.
    */
+  @Override
   public synchronized void dispose() {
     if (saslClient != null) {
       try {
@@ -134,4 +148,15 @@ public class SparkSaslClient {
       }
     }
   }
+
+  @Override
+  public byte[] wrap(byte[] data, int offset, int len) throws SaslException {
+    return saslClient.wrap(data, offset, len);
+  }
+
+  @Override
+  public byte[] unwrap(byte[] data, int offset, int len) throws SaslException {
+    return saslClient.unwrap(data, offset, len);
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/38d4e9e4/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java
index e87b17e..431cb67 100644
--- a/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java
@@ -44,7 +44,7 @@ import org.slf4j.LoggerFactory;
  * initial state to the "authenticated" state. (It is not a server in the sense of accepting
  * connections on some socket.)
  */
-public class SparkSaslServer {
+public class SparkSaslServer implements SaslEncryptionBackend {
   private final Logger logger = LoggerFactory.getLogger(SparkSaslServer.class);
 
   /**
@@ -60,26 +60,37 @@ public class SparkSaslServer {
   static final String DIGEST = "DIGEST-MD5";
 
   /**
-   * The quality of protection is just "auth". This means that we are doing
-   * authentication only, we are not supporting integrity or privacy protection of the
-   * communication channel after authentication. This could be changed to be configurable
-   * in the future.
+   * Quality of protection value that includes encryption.
    */
-  static final Map<String, String> SASL_PROPS = ImmutableMap.<String, String>builder()
-    .put(Sasl.QOP, "auth")
-    .put(Sasl.SERVER_AUTH, "true")
-    .build();
+  static final String QOP_AUTH_CONF = "auth-conf";
+
+  /**
+   * Quality of protection value that does not include encryption.
+   */
+  static final String QOP_AUTH = "auth";
 
   /** Identifier for a certain secret key within the secretKeyHolder. */
   private final String secretKeyId;
   private final SecretKeyHolder secretKeyHolder;
   private SaslServer saslServer;
 
-  public SparkSaslServer(String secretKeyId, SecretKeyHolder secretKeyHolder) {
+  public SparkSaslServer(
+      String secretKeyId,
+      SecretKeyHolder secretKeyHolder,
+      boolean alwaysEncrypt) {
     this.secretKeyId = secretKeyId;
     this.secretKeyHolder = secretKeyHolder;
+
+    // Sasl.QOP is a comma-separated list of supported values. The value that allows encryption
+    // is listed first since it's preferred over the non-encrypted one (if the client also
+    // lists both in the request).
+    String qop = alwaysEncrypt ? QOP_AUTH_CONF : String.format("%s,%s", QOP_AUTH_CONF, QOP_AUTH);
+    Map<String, String> saslProps = ImmutableMap.<String, String>builder()
+      .put(Sasl.SERVER_AUTH, "true")
+      .put(Sasl.QOP, qop)
+      .build();
     try {
-      this.saslServer = Sasl.createSaslServer(DIGEST, null, DEFAULT_REALM, SASL_PROPS,
+      this.saslServer = Sasl.createSaslServer(DIGEST, null, DEFAULT_REALM, saslProps,
         new DigestCallbackHandler());
     } catch (SaslException e) {
       throw Throwables.propagate(e);
@@ -93,6 +104,11 @@ public class SparkSaslServer {
     return saslServer != null && saslServer.isComplete();
   }
 
+  /** Returns the value of a negotiated property. */
+  public Object getNegotiatedProperty(String name) {
+    return saslServer.getNegotiatedProperty(name);
+  }
+
   /**
    * Used to respond to server SASL tokens.
    * @param token Server's SASL token
@@ -110,6 +126,7 @@ public class SparkSaslServer {
    * Disposes of any system resources or security-sensitive information the
    * SaslServer might be using.
    */
+  @Override
   public synchronized void dispose() {
     if (saslServer != null) {
       try {
@@ -122,6 +139,16 @@ public class SparkSaslServer {
     }
   }
 
+  @Override
+  public byte[] wrap(byte[] data, int offset, int len) throws SaslException {
+    return saslServer.wrap(data, offset, len);
+  }
+
+  @Override
+  public byte[] unwrap(byte[] data, int offset, int len) throws SaslException {
+    return saslServer.unwrap(data, offset, len);
+  }
+
   /**
    * Implementation of javax.security.auth.callback.CallbackHandler for SASL DIGEST-MD5 mechanism.
    */
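
The QOP handling in this change (SparkSaslServer advertising `auth-conf,auth` unless `alwaysEncrypt` is set, SparkSaslClient requesting a single value, and SaslRpcHandler adding encryption only when the negotiated `Sasl.QOP` is `auth-conf`) can be exercised in-process with the JDK's built-in DIGEST-MD5 provider. This is a minimal sketch that calls `javax.security.sasl` directly rather than Spark's wrappers; the user name, secret, protocol name (`example`) and server name (`localhost`) are hypothetical placeholders, and it assumes the JDK's default SASL provider and its DIGEST-MD5 ciphers are available.

```java
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.sasl.AuthorizeCallback;
import javax.security.sasl.RealmCallback;
import javax.security.sasl.RealmChoiceCallback;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;

public class QopNegotiationSketch {

  // Hypothetical credentials; in the diff they come from a SecretKeyHolder instead.
  static final String USER = "example-app";
  static final char[] SECRET = "example-secret".toCharArray();

  /** Both ends of one completed in-process handshake. */
  static final class Session {
    final SaslClient client;
    final SaslServer server;

    Session(SaslClient client, SaslServer server) {
      this.client = client;
      this.server = server;
    }
  }

  /** Answers the DIGEST-MD5 callbacks on either side with the same shared secret. */
  static CallbackHandler handler() {
    return callbacks -> {
      for (Callback cb : callbacks) {
        if (cb instanceof NameCallback) {
          ((NameCallback) cb).setName(USER);
        } else if (cb instanceof PasswordCallback) {
          ((PasswordCallback) cb).setPassword(SECRET);
        } else if (cb instanceof RealmCallback) {
          RealmCallback rc = (RealmCallback) cb;
          rc.setText(rc.getDefaultText());
        } else if (cb instanceof RealmChoiceCallback) {
          ((RealmChoiceCallback) cb).setSelectedIndex(0);
        } else if (cb instanceof AuthorizeCallback) {
          AuthorizeCallback ac = (AuthorizeCallback) cb;
          String authz = ac.getAuthorizationID();
          ac.setAuthorized(authz == null || authz.equals(ac.getAuthenticationID()));
        } else {
          throw new UnsupportedCallbackException(cb);
        }
      }
    };
  }

  /** Runs the challenge/response loop; the QOP arguments are Sasl.QOP values such as "auth-conf,auth". */
  static Session negotiate(String clientQop, String serverQop) throws SaslException {
    Map<String, String> clientProps = new HashMap<>();
    clientProps.put(Sasl.QOP, clientQop);
    Map<String, String> serverProps = new HashMap<>();
    serverProps.put(Sasl.QOP, serverQop);
    serverProps.put(Sasl.SERVER_AUTH, "true");

    // "example" and "localhost" are placeholder protocol and server names.
    SaslServer server =
        Sasl.createSaslServer("DIGEST-MD5", "example", "localhost", serverProps, handler());
    SaslClient client = Sasl.createSaslClient(
        new String[] {"DIGEST-MD5"}, null, "example", "localhost", clientProps, handler());

    // DIGEST-MD5 has no initial client response, so the exchange starts with an empty token.
    byte[] token = new byte[0];
    while (!client.isComplete()) {
      byte[] challenge = server.evaluateResponse(token);
      token = client.evaluateChallenge(challenge);
    }
    return new Session(client, server);
  }

  public static void main(String[] args) throws SaslException {
    // The client insists on encryption; the server accepts either, preferring auth-conf.
    Session s = negotiate("auth-conf", "auth-conf,auth");
    System.out.println("negotiated QOP: " + s.server.getNegotiatedProperty(Sasl.QOP));

    byte[] plain = "shuffle block bytes".getBytes(StandardCharsets.UTF_8);
    byte[] wrapped = s.client.wrap(plain, 0, plain.length);
    byte[] unwrapped = s.server.unwrap(wrapped, 0, wrapped.length);
    System.out.println(plain.length + " -> " + wrapped.length + " bytes, round trip: "
        + Arrays.equals(plain, unwrapped));

    s.client.dispose();
    s.server.dispose();
  }
}
```

Passing `"auth"` as the client QOP instead should leave `auth` as the negotiated value, which is the case where SaslRpcHandler disposes of the SASL server instead of adding encryption to the channel. The wrapped output is longer than the input (DIGEST-MD5 adds a truncated MAC and a sequence number, plus padding for block ciphers), which is why EncryptedMessage cannot know the encrypted size up front.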

http://git-wip-us.apache.org/repos/asf/spark/blob/38d4e9e4/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java
index b7ce854..941ef95 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java
@@ -19,8 +19,11 @@ package org.apache.spark.network.server;
 
 import java.io.Closeable;
 import java.net.InetSocketAddress;
+import java.util.List;
 import java.util.concurrent.TimeUnit;
 
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
 import io.netty.bootstrap.ServerBootstrap;
 import io.netty.buffer.PooledByteBufAllocator;
 import io.netty.channel.ChannelFuture;
@@ -44,15 +47,23 @@ public class TransportServer implements Closeable {
 
   private final TransportContext context;
   private final TransportConf conf;
+  private final RpcHandler appRpcHandler;
+  private final List<TransportServerBootstrap> bootstraps;
 
   private ServerBootstrap bootstrap;
   private ChannelFuture channelFuture;
   private int port = -1;
 
   /** Creates a TransportServer that binds to the given port, or to any available if 0. */
-  public TransportServer(TransportContext context, int portToBind) {
+  public TransportServer(
+      TransportContext context,
+      int portToBind,
+      RpcHandler appRpcHandler,
+      List<TransportServerBootstrap> bootstraps) {
     this.context = context;
     this.conf = context.getConf();
+    this.appRpcHandler = appRpcHandler;
+    this.bootstraps = Lists.newArrayList(Preconditions.checkNotNull(bootstraps));
 
     init(portToBind);
   }
@@ -95,7 +106,11 @@ public class TransportServer implements Closeable {
     bootstrap.childHandler(new ChannelInitializer<SocketChannel>() {
       @Override
       protected void initChannel(SocketChannel ch) throws Exception {
-        context.initializePipeline(ch);
+        RpcHandler rpcHandler = appRpcHandler;
+        for (TransportServerBootstrap bootstrap : bootstraps) {
+          rpcHandler = bootstrap.doBootstrap(ch, rpcHandler);
+        }
+        context.initializePipeline(ch, rpcHandler);
       }
     });
 

http://git-wip-us.apache.org/repos/asf/spark/blob/38d4e9e4/network/common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java
new file mode 100644
index 0000000..05803ab
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import io.netty.channel.Channel;
+
+/**
+ * A bootstrap which is executed on a TransportServer's client channel once a client connects
+ * to the server. This allows customizing the client channel for features such as SASL
+ * authentication.
+ */
+public interface TransportServerBootstrap {
+  /**
+   * Customizes the channel to include new features, if needed.
+   *
+   * @param channel The connected channel opened by the client.
+   * @param rpcHandler The RPC handler for the server.
+   * @return The RPC handler to use for the channel.
+   */
+  RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler);
+}
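
The loop that TransportServer now runs in `initChannel` folds the bootstraps over the application's RpcHandler: each bootstrap wraps whatever the previous one returned, so the last bootstrap in the list sees incoming RPCs first, and because `doBootstrap` runs once per connection, each channel gets its own handler state (which is what lets SaslRpcHandler drop its per-client map). This JDK-only sketch uses hypothetical `Handler` and `Bootstrap` interfaces as stand-ins for RpcHandler and TransportServerBootstrap, with the Channel argument omitted.

```java
import java.util.Arrays;
import java.util.List;

public class BootstrapChainSketch {

  /** Hypothetical stand-in for RpcHandler: maps a request to a reply. */
  interface Handler {
    String receive(String message);
  }

  /** Hypothetical stand-in for TransportServerBootstrap, without the Channel parameter. */
  interface Bootstrap {
    Handler doBootstrap(Handler handler);
  }

  /** Same shape as the loop in TransportServer.initChannel(). */
  static Handler applyBootstraps(Handler appHandler, List<Bootstrap> bootstraps) {
    Handler handler = appHandler;
    for (Bootstrap bootstrap : bootstraps) {
      handler = bootstrap.doBootstrap(handler);
    }
    return handler;
  }

  /** Wraps the delegate so the message records which layer it passed through. */
  static Bootstrap tagging(String name) {
    return delegate -> message -> delegate.receive(name + "(" + message + ")");
  }

  /**
   * Rejects messages until it sees "auth", loosely like SaslRpcHandler's isComplete flag.
   * doBootstrap runs once per connection, so every channel gets its own flag.
   */
  static final Bootstrap AUTH_GATE = delegate -> new Handler() {
    private boolean authenticated = false;

    @Override
    public String receive(String message) {
      if (authenticated) {
        return delegate.receive(message);
      }
      if (message.equals("auth")) {
        authenticated = true;
        return "ok";
      }
      return "rejected";
    }
  };

  public static void main(String[] args) {
    Handler app = message -> "app:" + message;

    Handler layered = applyBootstraps(app, Arrays.asList(tagging("first"), tagging("second")));
    System.out.println(layered.receive("req"));  // prints app:first(second(req))

    Handler channel = applyBootstraps(app, Arrays.asList(AUTH_GATE));
    System.out.println(channel.receive("hello"));  // prints rejected
    System.out.println(channel.receive("auth"));   // prints ok
    System.out.println(channel.receive("hello"));  // prints app:hello
  }
}
```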

http://git-wip-us.apache.org/repos/asf/spark/blob/38d4e9e4/network/common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java b/network/common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java
new file mode 100644
index 0000000..b141572
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import java.nio.ByteBuffer;
+import java.nio.channels.WritableByteChannel;
+
+/**
+ * A writable channel that stores the written data in a byte array in memory.
+ */
+public class ByteArrayWritableChannel implements WritableByteChannel {
+
+  private final byte[] data;
+  private int offset;
+
+  public ByteArrayWritableChannel(int size) {
+    this.data = new byte[size];
+  }
+
+  public byte[] getData() {
+    return data;
+  }
+
+  public int length() {
+    return offset;
+  }
+
+  /** Resets the channel so that writing to it will overwrite the existing buffer. */
+  public void reset() {
+    offset = 0;
+  }
+
+  /**
+   * Copies as many bytes from the given buffer as fit in the internal array; returns the count.
+   */
+  @Override
+  public int write(ByteBuffer src) {
+    int toTransfer = Math.min(src.remaining(), data.length - offset);
+    src.get(data, offset, toTransfer);
+    offset += toTransfer;
+    return toTransfer;
+  }
+
+  @Override
+  public void close() {
+
+  }
+
+  @Override
+  public boolean isOpen() {
+    return true;
+  }
+
+}
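
The new ByteArrayWritableChannel caps each write at the array's remaining capacity and returns the number of bytes actually copied, whereas the deleted test-only version copied `src.remaining()` bytes unconditionally. That short-write behavior is what lets `EncryptedMessage.nextChunk()` pull at most maxOutboundBlockSize bytes per chunk. The sketch below uses a local copy of the class (named `BoundedChannel` here, a hypothetical name so it compiles without Spark) to drain a buffer one chunk at a time.

```java
import java.nio.ByteBuffer;
import java.nio.channels.WritableByteChannel;
import java.util.ArrayList;
import java.util.List;

public class BoundedChannelSketch {

  /** Local copy of the ByteArrayWritableChannel logic from the diff. */
  static final class BoundedChannel implements WritableByteChannel {
    private final byte[] data;
    private int offset;

    BoundedChannel(int size) {
      this.data = new byte[size];
    }

    byte[] getData() { return data; }
    int length() { return offset; }
    void reset() { offset = 0; }

    @Override
    public int write(ByteBuffer src) {
      // Copy only what fits; once the array is full the caller sees a short (or zero) write.
      int toTransfer = Math.min(src.remaining(), data.length - offset);
      src.get(data, offset, toTransfer);
      offset += toTransfer;
      return toTransfer;
    }

    @Override
    public boolean isOpen() { return true; }

    @Override
    public void close() { }
  }

  /** Drains src through the channel one buffer-full at a time, recording each chunk's size. */
  static List<Integer> chunkSizes(ByteBuffer src, int capacity) {
    BoundedChannel channel = new BoundedChannel(capacity);
    List<Integer> sizes = new ArrayList<>();
    while (src.hasRemaining()) {
      channel.reset();
      channel.write(src);  // advances src.position() by the number of bytes copied
      sizes.add(channel.length());
    }
    return sizes;
  }

  public static void main(String[] args) {
    System.out.println(chunkSizes(ByteBuffer.allocate(150), 64));  // prints [64, 64, 22]
  }
}
```

Writing a ByteBuffer directly advances its position; the ByteBuf path in `nextChunk()` instead passes a `nioBuffer()` view and then calls `buf.skipBytes(copied)` to advance the ByteBuf by the amount the channel accepted.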

http://git-wip-us.apache.org/repos/asf/spark/blob/38d4e9e4/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
index 0aef7f1..3b2eff3 100644
--- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -17,6 +17,8 @@
 
 package org.apache.spark.network.util;
 
+import com.google.common.primitives.Ints;
+
 /**
  * A central location that tracks all the settings we expose to users.
  */
@@ -112,4 +114,20 @@ public class TransportConf {
   public int portMaxRetries() {
     return conf.getInt("spark.port.maxRetries", 16);
   }
+
+  /**
+   * Maximum number of bytes to be encrypted at a time when SASL encryption is enabled.
+   */
+  public int maxSaslEncryptedBlockSize() {
+    return Ints.checkedCast(JavaUtils.byteStringAsBytes(
+      conf.get("spark.network.sasl.maxEncryptedBlockSize", "64k")));
+  }
+
+  /**
+   * Whether the server should enforce encryption on SASL-authenticated connections.
+   */
+  public boolean saslServerAlwaysEncrypt() {
+    return conf.getBoolean("spark.network.sasl.serverAlwaysEncrypt", false);
+  }
+
 }
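
`maxSaslEncryptedBlockSize()` reads a size string (`"64k"` by default) and narrows it to an int with `Ints.checkedCast`, while `saslServerAlwaysEncrypt()` defaults to false. The sketch below mimics both lookups over a plain Map. Its parser is a simplified, hypothetical stand-in for `JavaUtils.byteStringAsBytes`, not Spark's implementation, and it assumes binary (1024-based) suffixes; `Math.toIntExact` plays the role of `Ints.checkedCast`, but throws ArithmeticException rather than IllegalArgumentException.

```java
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class SaslConfSketch {

  // Simplified, hypothetical size syntax: digits plus an optional b/k/kb/m/mb/g/gb suffix.
  private static final Pattern SIZE = Pattern.compile("(\\d+)\\s*(b|kb?|mb?|gb?)?");

  static long parseBytes(String str) {
    Matcher m = SIZE.matcher(str.trim().toLowerCase(Locale.ROOT));
    if (!m.matches()) {
      throw new NumberFormatException("Invalid size string: " + str);
    }
    long value = Long.parseLong(m.group(1));
    String suffix = m.group(2) == null ? "b" : m.group(2);
    switch (suffix.charAt(0)) {
      case 'k': return value << 10;
      case 'm': return value << 20;
      case 'g': return value << 30;
      default:  return value;  // plain bytes
    }
  }

  /** Like maxSaslEncryptedBlockSize(): "64k" by default, and the result must fit in an int. */
  static int maxEncryptedBlockSize(Map<String, String> conf) {
    String raw = conf.getOrDefault("spark.network.sasl.maxEncryptedBlockSize", "64k");
    return Math.toIntExact(parseBytes(raw));  // throws ArithmeticException on overflow
  }

  /** Like saslServerAlwaysEncrypt(): off unless explicitly enabled. */
  static boolean serverAlwaysEncrypt(Map<String, String> conf) {
    return Boolean.parseBoolean(
        conf.getOrDefault("spark.network.sasl.serverAlwaysEncrypt", "false"));
  }

  public static void main(String[] args) {
    Map<String, String> conf = new HashMap<>();
    System.out.println(maxEncryptedBlockSize(conf) + " " + serverAlwaysEncrypt(conf));  // prints 65536 false

    conf.put("spark.network.sasl.maxEncryptedBlockSize", "1m");
    conf.put("spark.network.sasl.serverAlwaysEncrypt", "true");
    System.out.println(maxEncryptedBlockSize(conf) + " " + serverAlwaysEncrypt(conf));  // prints 1048576 true
  }
}
```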

http://git-wip-us.apache.org/repos/asf/spark/blob/38d4e9e4/network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java
----------------------------------------------------------------------
diff --git a/network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java b/network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java
deleted file mode 100644
index b525ed6..0000000
--- a/network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java
+++ /dev/null
@@ -1,55 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network;
-
-import java.nio.ByteBuffer;
-import java.nio.channels.WritableByteChannel;
-
-public class ByteArrayWritableChannel implements WritableByteChannel {
-
-  private final byte[] data;
-  private int offset;
-
-  public ByteArrayWritableChannel(int size) {
-    this.data = new byte[size];
-    this.offset = 0;
-  }
-
-  public byte[] getData() {
-    return data;
-  }
-
-  @Override
-  public int write(ByteBuffer src) {
-    int available = src.remaining();
-    src.get(data, offset, available);
-    offset += available;
-    return available;
-  }
-
-  @Override
-  public void close() {
-
-  }
-
-  @Override
-  public boolean isOpen() {
-    return true;
-  }
-
-}

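This change deletes the test helper above and imports `ByteArrayWritableChannel` from `org.apache.spark.network.util` instead (see the updated imports below). The new tests call `channel.reset()` and expect writes to a small channel to be partial. The deleted `write()` always copies `src.remaining()` bytes, so any write larger than the free space throws `IndexOutOfBoundsException`, and the class has no `reset()`. The sketch below is a bounded, resettable variant that copies only the bytes that fit. It is not the source of the relocated class, and the class name and `length()` accessor are hypothetical.

```java
import java.nio.ByteBuffer;
import java.nio.channels.WritableByteChannel;

/**
 * Hypothetical fixed-capacity, resettable WritableByteChannel for tests. write() accepts
 * only as many bytes as fit, simulating a socket that performs partial writes.
 */
final class BoundedByteArrayChannel implements WritableByteChannel {
  private final byte[] data;
  private int offset;

  BoundedByteArrayChannel(int size) {
    this.data = new byte[size];
  }

  /** Discards buffered bytes so the channel can accept writes again. */
  void reset() {
    offset = 0;
  }

  /** Number of bytes written since the last reset. */
  int length() {
    return offset;
  }

  byte[] getData() {
    return data;
  }

  @Override
  public int write(ByteBuffer src) {
    // Copy only what fits; the remainder stays in src for a later call.
    int toWrite = Math.min(src.remaining(), data.length - offset);
    src.get(data, offset, toWrite);
    offset += toWrite;
    return toWrite;
  }

  @Override
  public void close() {}

  @Override
  public boolean isOpen() {
    return true;
  }
}
```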
http://git-wip-us.apache.org/repos/asf/spark/blob/38d4e9e4/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
----------------------------------------------------------------------
diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
index 860dd6d..d500bc3 100644
--- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
@@ -39,6 +39,7 @@ import org.apache.spark.network.protocol.RpcFailure;
 import org.apache.spark.network.protocol.RpcRequest;
 import org.apache.spark.network.protocol.RpcResponse;
 import org.apache.spark.network.protocol.StreamChunkId;
+import org.apache.spark.network.util.ByteArrayWritableChannel;
 import org.apache.spark.network.util.NettyUtils;
 
 public class ProtocolSuite {

http://git-wip-us.apache.org/repos/asf/spark/blob/38d4e9e4/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java
----------------------------------------------------------------------
diff --git a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java
index ff98509..6c98e73 100644
--- a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java
@@ -29,7 +29,7 @@ import org.junit.Test;
 
 import static org.junit.Assert.*;
 
-import org.apache.spark.network.ByteArrayWritableChannel;
+import org.apache.spark.network.util.ByteArrayWritableChannel;
 
 public class MessageWithHeaderSuite {
 

http://git-wip-us.apache.org/repos/asf/spark/blob/38d4e9e4/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
----------------------------------------------------------------------
diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
index 23b4e06..be6632b 100644
--- a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
@@ -17,12 +17,47 @@
 
 package org.apache.spark.network.sasl;
 
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
+import static com.google.common.base.Charsets.UTF_8;
+import static org.junit.Assert.*;
+import static org.mockito.Mockito.*;
 
+import java.io.File;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.TimeoutException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
+import javax.security.sasl.SaslException;
+
+import com.google.common.collect.Lists;
+import com.google.common.io.ByteStreams;
+import com.google.common.io.Files;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelOutboundHandlerAdapter;
+import io.netty.channel.ChannelPromise;
 import org.junit.Test;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
 
+import org.apache.spark.network.TestUtils;
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.client.ChunkReceivedCallback;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientBootstrap;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.server.TransportServerBootstrap;
+import org.apache.spark.network.util.ByteArrayWritableChannel;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
+import org.apache.spark.network.util.TransportConf;
 
 /**
  * Jointly tests SparkSaslClient and SparkSaslServer, as both are black boxes.
@@ -44,8 +79,8 @@ public class SparkSaslSuite {
 
   @Test
   public void testMatching() {
-    SparkSaslClient client = new SparkSaslClient("shared-secret", secretKeyHolder);
-    SparkSaslServer server = new SparkSaslServer("shared-secret", secretKeyHolder);
+    SparkSaslClient client = new SparkSaslClient("shared-secret", secretKeyHolder, false);
+    SparkSaslServer server = new SparkSaslServer("shared-secret", secretKeyHolder, false);
 
     assertFalse(client.isComplete());
     assertFalse(server.isComplete());
@@ -64,11 +99,10 @@ public class SparkSaslSuite {
     assertFalse(client.isComplete());
   }
 
-
   @Test
   public void testNonMatching() {
-    SparkSaslClient client = new SparkSaslClient("my-secret", secretKeyHolder);
-    SparkSaslServer server = new SparkSaslServer("your-secret", secretKeyHolder);
+    SparkSaslClient client = new SparkSaslClient("my-secret", secretKeyHolder, false);
+    SparkSaslServer server = new SparkSaslServer("your-secret", secretKeyHolder, false);
 
     assertFalse(client.isComplete());
     assertFalse(server.isComplete());
@@ -86,4 +120,312 @@ public class SparkSaslSuite {
       assertFalse(server.isComplete());
     }
   }
+
+  @Test
+  public void testSaslAuthentication() throws Exception {
+    testBasicSasl(false);
+  }
+
+  @Test
+  public void testSaslEncryption() throws Exception {
+    testBasicSasl(true);
+  }
+
+  private void testBasicSasl(boolean encrypt) throws Exception {
+    RpcHandler rpcHandler = mock(RpcHandler.class);
+    doAnswer(new Answer<Void>() {
+        @Override
+        public Void answer(InvocationOnMock invocation) {
+          byte[] message = (byte[]) invocation.getArguments()[1];
+          RpcResponseCallback cb = (RpcResponseCallback) invocation.getArguments()[2];
+          assertEquals("Ping", new String(message, UTF_8));
+          cb.onSuccess("Pong".getBytes(UTF_8));
+          return null;
+        }
+      })
+      .when(rpcHandler)
+      .receive(any(TransportClient.class), any(byte[].class), any(RpcResponseCallback.class));
+
+    SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false);
+    try {
+      byte[] response = ctx.client.sendRpcSync("Ping".getBytes(UTF_8), TimeUnit.SECONDS.toMillis(10));
+      assertEquals("Pong", new String(response, UTF_8));
+    } finally {
+      ctx.close();
+    }
+  }
+
+  @Test
+  public void testEncryptedMessage() throws Exception {
+    SaslEncryptionBackend backend = mock(SaslEncryptionBackend.class);
+    byte[] data = new byte[1024];
+    new Random().nextBytes(data);
+    when(backend.wrap(any(byte[].class), anyInt(), anyInt())).thenReturn(data);
+
+    ByteBuf msg = Unpooled.buffer();
+    try {
+      msg.writeBytes(data);
+
+      // Create a channel with a really small buffer compared to the data. This means that on each
+      // call, the outbound data will not be fully written, so transferTo() should return a dummy
+      // count of 1, rather than 0, whenever it manages to write some bytes.
+      ByteArrayWritableChannel channel = new ByteArrayWritableChannel(32);
+
+      SaslEncryption.EncryptedMessage emsg =
+        new SaslEncryption.EncryptedMessage(backend, msg, 1024);
+      long count = emsg.transferTo(channel, emsg.transfered());
+      assertTrue(count < data.length);
+      assertTrue(count > 0);
+
+      // Here, the output buffer is full so nothing should be transferred.
+      assertEquals(0, emsg.transferTo(channel, emsg.transfered()));
+
+      // Now there's room in the buffer, but not enough to transfer all the remaining data,
+      // so the dummy count should be returned.
+      channel.reset();
+      assertEquals(1, emsg.transferTo(channel, emsg.transfered()));
+
+      // Eventually, the whole message should be transferred.
+      for (int i = 0; i < data.length / 32 - 2; i++) {
+        channel.reset();
+        assertEquals(1, emsg.transferTo(channel, emsg.transfered()));
+      }
+
+      channel.reset();
+      count = emsg.transferTo(channel, emsg.transfered());
+      assertTrue("Unexpected count: " + count, count > 1 && count < data.length);
+      assertEquals(data.length, emsg.transfered());
+    } finally {
+      msg.release();
+    }
+  }
+
+  @Test
+  public void testEncryptedMessageChunking() throws Exception {
+    File file = File.createTempFile("sasltest", ".txt");
+    try {
+      TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
+
+      byte[] data = new byte[8 * 1024];
+      new Random().nextBytes(data);
+      Files.write(data, file);
+
+      SaslEncryptionBackend backend = mock(SaslEncryptionBackend.class);
+      // It doesn't really matter what we return here, as long as it's not null.
+      when(backend.wrap(any(byte[].class), anyInt(), anyInt())).thenReturn(data);
+
+      FileSegmentManagedBuffer msg = new FileSegmentManagedBuffer(conf, file, 0, file.length());
+      SaslEncryption.EncryptedMessage emsg =
+        new SaslEncryption.EncryptedMessage(backend, msg.convertToNetty(), data.length / 8);
+
+      ByteArrayWritableChannel channel = new ByteArrayWritableChannel(data.length);
+      while (emsg.transfered() < emsg.count()) {
+        channel.reset();
+        emsg.transferTo(channel, emsg.transfered());
+      }
+
+      verify(backend, times(8)).wrap(any(byte[].class), anyInt(), anyInt());
+    } finally {
+      file.delete();
+    }
+  }
+
+  @Test
+  public void testFileRegionEncryption() throws Exception {
+    final String blockSizeConf = "spark.network.sasl.maxEncryptedBlockSize";
+    System.setProperty(blockSizeConf, "1k");
+
+    final AtomicReference<ManagedBuffer> response = new AtomicReference<ManagedBuffer>();
+    final File file = File.createTempFile("sasltest", ".txt");
+    SaslTestCtx ctx = null;
+    try {
+      final TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
+      StreamManager sm = mock(StreamManager.class);
+      when(sm.getChunk(anyLong(), anyInt())).thenAnswer(new Answer<ManagedBuffer>() {
+          @Override
+          public ManagedBuffer answer(InvocationOnMock invocation) {
+            return new FileSegmentManagedBuffer(conf, file, 0, file.length());
+          }
+        });
+
+      RpcHandler rpcHandler = mock(RpcHandler.class);
+      when(rpcHandler.getStreamManager()).thenReturn(sm);
+
+      byte[] data = new byte[8 * 1024];
+      new Random().nextBytes(data);
+      Files.write(data, file);
+
+      ctx = new SaslTestCtx(rpcHandler, true, false);
+
+      final Object lock = new Object();
+
+      ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
+      doAnswer(new Answer<Void>() {
+          @Override
+          public Void answer(InvocationOnMock invocation) {
+            response.set((ManagedBuffer) invocation.getArguments()[1]);
+            response.get().retain();
+            synchronized (lock) {
+              lock.notifyAll();
+            }
+            return null;
+          }
+        }).when(callback).onSuccess(anyInt(), any(ManagedBuffer.class));
+
+      synchronized (lock) {
+        ctx.client.fetchChunk(0, 0, callback);
+        lock.wait(10 * 1000);
+      }
+
+      verify(callback, times(1)).onSuccess(anyInt(), any(ManagedBuffer.class));
+      verify(callback, never()).onFailure(anyInt(), any(Throwable.class));
+
+      byte[] received = ByteStreams.toByteArray(response.get().createInputStream());
+      assertTrue(Arrays.equals(data, received));
+    } finally {
+      file.delete();
+      if (ctx != null) {
+        ctx.close();
+      }
+      if (response.get() != null) {
+        response.get().release();
+      }
+      System.clearProperty(blockSizeConf);
+    }
+  }
+
+  @Test
+  public void testServerAlwaysEncrypt() throws Exception {
+    final String alwaysEncryptConfName = "spark.network.sasl.serverAlwaysEncrypt";
+    System.setProperty(alwaysEncryptConfName, "true");
+
+    SaslTestCtx ctx = null;
+    try {
+      ctx = new SaslTestCtx(mock(RpcHandler.class), false, false);
+      fail("Should have failed to connect without encryption.");
+    } catch (Exception e) {
+      assertTrue(e.getCause() instanceof SaslException);
+    } finally {
+      if (ctx != null) {
+        ctx.close();
+      }
+      System.clearProperty(alwaysEncryptConfName);
+    }
+  }
+
+  @Test
+  public void testDataEncryptionIsActuallyEnabled() throws Exception {
+    // This test sets up an encrypted connection but then, using a client bootstrap, removes
+    // the encryption handler from the client side. This should cause the server to not be
+    // able to understand RPCs sent to it and thus close the connection.
+    SaslTestCtx ctx = null;
+    try {
+      ctx = new SaslTestCtx(mock(RpcHandler.class), true, true);
+      ctx.client.sendRpcSync("Ping".getBytes(UTF_8), TimeUnit.SECONDS.toMillis(10));
+      fail("Should have failed to send RPC to server.");
+    } catch (Exception e) {
+      assertFalse(e.getCause() instanceof TimeoutException);
+    } finally {
+      if (ctx != null) {
+        ctx.close();
+      }
+    }
+  }
+
+  private static class SaslTestCtx {
+
+    final TransportClient client;
+    final TransportServer server;
+
+    private final boolean encrypt;
+    private final boolean disableClientEncryption;
+    private final EncryptionCheckerBootstrap checker;
+
+    SaslTestCtx(
+        RpcHandler rpcHandler,
+        boolean encrypt,
+        boolean disableClientEncryption)
+      throws Exception {
+
+      TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
+
+      SecretKeyHolder keyHolder = mock(SecretKeyHolder.class);
+      when(keyHolder.getSaslUser(anyString())).thenReturn("user");
+      when(keyHolder.getSecretKey(anyString())).thenReturn("secret");
+
+      TransportContext ctx = new TransportContext(conf, rpcHandler);
+
+      this.checker = new EncryptionCheckerBootstrap();
+      this.server = ctx.createServer(Arrays.asList(new SaslServerBootstrap(conf, keyHolder),
+        checker));
+
+      try {
+        List<TransportClientBootstrap> clientBootstraps = Lists.newArrayList();
+        clientBootstraps.add(new SaslClientBootstrap(conf, "user", keyHolder, encrypt));
+        if (disableClientEncryption) {
+          clientBootstraps.add(new EncryptionDisablerBootstrap());
+        }
+
+        this.client = ctx.createClientFactory(clientBootstraps)
+          .createClient(TestUtils.getLocalHost(), server.getPort());
+      } catch (Exception e) {
+        close();
+        throw e;
+      }
+
+      this.encrypt = encrypt;
+      this.disableClientEncryption = disableClientEncryption;
+    }
+
+    void close() {
+      if (!disableClientEncryption) {
+        assertEquals(encrypt, checker.foundEncryptionHandler);
+      }
+      if (client != null) {
+        client.close();
+      }
+      if (server != null) {
+        server.close();
+      }
+    }
+
+  }
+
+  private static class EncryptionCheckerBootstrap extends ChannelOutboundHandlerAdapter
+    implements TransportServerBootstrap {
+
+    boolean foundEncryptionHandler;
+
+    @Override
+    public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
+      throws Exception {
+      if (!foundEncryptionHandler) {
+        foundEncryptionHandler =
+          ctx.channel().pipeline().get(SaslEncryption.ENCRYPTION_HANDLER_NAME) != null;
+      }
+      ctx.write(msg, promise);
+    }
+
+    @Override
+    public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
+      super.handlerRemoved(ctx);
+    }
+
+    @Override
+    public RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler) {
+      channel.pipeline().addFirst("encryptionChecker", this);
+      return rpcHandler;
+    }
+
+  }
+
+  private static class EncryptionDisablerBootstrap implements TransportClientBootstrap {
+
+    @Override
+    public void doBootstrap(TransportClient client, Channel channel) {
+      channel.pipeline().remove(SaslEncryption.ENCRYPTION_HANDLER_NAME);
+    }
+
+  }
+
 }

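`testEncryptedMessage` checks how `EncryptedMessage.transferTo()` reports progress while its output is only partly written. It returns 1 when some bytes went out, 0 when the channel accepted nothing, and the remaining count once the frame completes. The commit list describes this as "Avoid trip through channel selector by not reporting 0 bytes written". The sketch below is a hypothetical, simplified model of that reporting rule. It covers only one chunk whose encrypted frame is already computed, and it is not Spark's implementation.

```java
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.WritableByteChannel;

/**
 * Hypothetical model of "dummy count" progress reporting for one encrypted chunk.
 * Not Spark's EncryptedMessage; headers, multiple chunks and buffer release are omitted.
 */
final class DummyCountTransfer {
  private final ByteBuffer frame;   // encrypted bytes still to be written
  private final long plainSize;     // plaintext bytes this frame represents
  private long reported;            // plaintext bytes reported to the caller so far

  DummyCountTransfer(byte[] encryptedFrame, long plainSize) {
    this.frame = ByteBuffer.wrap(encryptedFrame);
    this.plainSize = plainSize;
  }

  long transferred() {
    return reported;
  }

  /** Writes what the target accepts and returns the plaintext progress to report. */
  long transferTo(WritableByteChannel target) throws IOException {
    int written = target.write(frame);
    if (!frame.hasRemaining()) {
      // Frame complete: report whatever has not been reported yet (0 on later calls).
      long rest = plainSize - reported;
      reported = plainSize;
      return rest;
    }
    // Partial write: report 1 instead of 0 so the caller sees progress, but hold back
    // at least one byte so the completing call still reports a positive remainder.
    if (written > 0 && reported < plainSize - 1) {
      reported += 1;
      return 1;
    }
    return 0;
  }
}
```

The sum of the returned values always equals `plainSize`, because the dummy counts stop at `plainSize - 1` and the completing call reports the difference. This matches the test's final check that `transfered()` equals `data.length`.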
http://git-wip-us.apache.org/repos/asf/spark/blob/38d4e9e4/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
----------------------------------------------------------------------
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
index 6e8018b..612bce5 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
@@ -20,6 +20,7 @@ package org.apache.spark.network.shuffle;
 import java.io.IOException;
 import java.util.List;
 
+import com.google.common.base.Preconditions;
 import com.google.common.collect.Lists;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -46,6 +47,7 @@ public class ExternalShuffleClient extends ShuffleClient {
 
   private final TransportConf conf;
   private final boolean saslEnabled;
+  private final boolean saslEncryptionEnabled;
   private final SecretKeyHolder secretKeyHolder;
 
   private TransportClientFactory clientFactory;
@@ -58,10 +60,15 @@ public class ExternalShuffleClient extends ShuffleClient {
   public ExternalShuffleClient(
       TransportConf conf,
       SecretKeyHolder secretKeyHolder,
-      boolean saslEnabled) {
+      boolean saslEnabled,
+      boolean saslEncryptionEnabled) {
+    Preconditions.checkArgument(
+      !saslEncryptionEnabled || saslEnabled,
+      "SASL encryption can only be enabled if SASL is also enabled.");
     this.conf = conf;
     this.secretKeyHolder = secretKeyHolder;
     this.saslEnabled = saslEnabled;
+    this.saslEncryptionEnabled = saslEncryptionEnabled;
   }
 
   @Override
@@ -70,7 +77,7 @@ public class ExternalShuffleClient extends ShuffleClient {
     TransportContext context = new TransportContext(conf, new NoOpRpcHandler());
     List<TransportClientBootstrap> bootstraps = Lists.newArrayList();
     if (saslEnabled) {
-      bootstraps.add(new SaslClientBootstrap(conf, appId, secretKeyHolder));
+      bootstraps.add(new SaslClientBootstrap(conf, appId, secretKeyHolder, saslEncryptionEnabled));
     }
     clientFactory = context.createClientFactory(bootstraps);
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/38d4e9e4/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
----------------------------------------------------------------------
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
index d25283e..382f613 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
@@ -18,6 +18,7 @@
 package org.apache.spark.network.sasl;
 
 import java.io.IOException;
+import java.util.Arrays;
 
 import com.google.common.collect.Lists;
 import org.junit.After;
@@ -37,6 +38,7 @@ import org.apache.spark.network.server.OneForOneStreamManager;
 import org.apache.spark.network.server.RpcHandler;
 import org.apache.spark.network.server.StreamManager;
 import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.server.TransportServerBootstrap;
 import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler;
 import org.apache.spark.network.util.SystemPropertyConfigProvider;
 import org.apache.spark.network.util.TransportConf;
@@ -72,10 +74,11 @@ public class SaslIntegrationSuite {
   @BeforeClass
   public static void beforeAll() throws IOException {
     SecretKeyHolder secretKeyHolder = new TestSecretKeyHolder("good-key");
-    SaslRpcHandler handler = new SaslRpcHandler(new TestRpcHandler(), secretKeyHolder);
     conf = new TransportConf(new SystemPropertyConfigProvider());
-    context = new TransportContext(conf, handler);
-    server = context.createServer();
+    context = new TransportContext(conf, new TestRpcHandler());
+
+    TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, secretKeyHolder);
+    server = context.createServer(Arrays.asList(bootstrap));
   }
 
 

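After this change, the test no longer wraps its handler in a `SaslRpcHandler`. Instead it passes a `SaslServerBootstrap` to `createServer()`. Each `TransportServerBootstrap.doBootstrap(Channel, RpcHandler)` returns the handler to use next, and `EncryptionCheckerBootstrap` above simply returns its input. The sketch below models that chaining. It assumes the bootstraps are applied in list order, with each one receiving the previous result. The interfaces are hypothetical simplifications that drop the `Channel` argument and are not Spark's API.

```java
import java.util.List;

/**
 * Hypothetical model of chaining server-side bootstraps: each bootstrap receives the
 * handler produced so far and may return a wrapper around it (or the same handler).
 */
final class ServerBootstrapChainSketch {

  /** Hypothetical stand-in for RpcHandler. */
  interface Handler {
    String receive(String message);
  }

  /** Hypothetical stand-in for TransportServerBootstrap, without the Channel argument. */
  interface ServerBootstrap {
    Handler doBootstrap(Handler next);
  }

  /** Applies the bootstraps in list order, starting from the application's handler. */
  static Handler apply(List<ServerBootstrap> bootstraps, Handler appHandler) {
    Handler handler = appHandler;
    for (ServerBootstrap bootstrap : bootstraps) {
      handler = bootstrap.doBootstrap(handler);
    }
    return handler;
  }
}
```

With this design the application handler stays unaware of SASL. A bootstrap such as the SASL one can put an authenticating wrapper in front of it, and the caller only chooses which bootstraps to pass.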
http://git-wip-us.apache.org/repos/asf/spark/blob/38d4e9e4/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
----------------------------------------------------------------------
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
index 02c10bc..39aa499 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
@@ -136,7 +136,7 @@ public class ExternalShuffleIntegrationSuite {
 
     final Semaphore requestsRemaining = new Semaphore(0);
 
-    ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false);
+    ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false, false);
     client.init(APP_ID);
     client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds,
       new BlockFetchingListener() {
@@ -274,7 +274,7 @@ public class ExternalShuffleIntegrationSuite {
 
   private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo)
       throws IOException {
-    ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false);
+    ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false, false);
     client.init(APP_ID);
     client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(),
       executorId, executorInfo);



