spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From pwend...@apache.org
Subject git commit: [SPARK-2583] ConnectionManager error reporting
Date Thu, 07 Aug 2014 00:28:01 GMT
Repository: spark
Updated Branches:
  refs/heads/master 4e008334e -> 17caae48b


[SPARK-2583] ConnectionManager error reporting

This patch modifies the ConnectionManager so that error messages are sent in reply when uncaught
exceptions occur during message processing.  This prevents message senders from hanging while
waiting for an acknowledgment if the remote message processing failed.

This is an updated version of sarutak's PR, #1490.  The main change is to use Futures / Promises
to signal errors.

Author: Kousuke Saruta <sarutak@oss.nttdata.co.jp>
Author: Josh Rosen <joshrosen@apache.org>

Closes #1758 from JoshRosen/connection-manager-fixes and squashes the following commits:

68620cb [Josh Rosen] Fix test in BlockFetcherIteratorSuite:
83673de [Josh Rosen] Error ACKs should trigger IOExceptions, so catch only those exceptions
in the test.
b8bb4d4 [Josh Rosen] Fix manager.id vs managerServer.id typo that broke security tests.
659521f [Josh Rosen] Include previous exception when throwing new one
a2f745c [Josh Rosen] Remove sendMessageReliablySync; callers can wait themselves.
c01c450 [Josh Rosen] Return Try[Message] from sendMessageReliablySync.
f1cd1bb [Josh Rosen] Clean up @sarutak's PR #1490 for [SPARK-2583]: ConnectionManager error
reporting
7399c6b [Josh Rosen] Merge remote-tracking branch 'origin/pr/1490' into connection-manager-fixes
ee91bb7 [Kousuke Saruta] Modified BufferMessage.scala to keep the spark code style
9dfd0d8 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2583
e7d9aa6 [Kousuke Saruta] rebase to master
326a17f [Kousuke Saruta] Add test cases to ConnectionManagerSuite.scala for SPARK-2583
2a18d6b [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2583
22d7ebd [Kousuke Saruta] Add test cases to BlockManagerSuite for SPARK-2583
e579302 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2583
281589c [Kousuke Saruta] Add a test case to BlockFetcherIteratorSuite.scala for fetching a block
from remote unsuccessfully
0654128 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2583
ffaa83d [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2583
12d3de8 [Kousuke Saruta] Added BlockFetcherIteratorSuite.scala
4117b8f [Kousuke Saruta] Modified ConnectionManager to be able to handle errors during processing of
messages
717c9c3 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2583
6635467 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2583
e2b8c4a [Kousuke Saruta] Modify to propagate errors using ConnectionManager


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/17caae48
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/17caae48
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/17caae48

Branch: refs/heads/master
Commit: 17caae48b3608552dd6e3ae652043831f932ce95
Parents: 4e00833
Author: Kousuke Saruta <sarutak@oss.nttdata.co.jp>
Authored: Wed Aug 6 17:27:55 2014 -0700
Committer: Patrick Wendell <pwendell@gmail.com>
Committed: Wed Aug 6 17:27:55 2014 -0700

----------------------------------------------------------------------
 .../apache/spark/network/BufferMessage.scala    |   7 +-
 .../spark/network/ConnectionManager.scala       | 143 +++++++++++--------
 .../org/apache/spark/network/Message.scala      |   2 +
 .../spark/network/MessageChunkHeader.scala      |   7 +-
 .../org/apache/spark/network/SenderTest.scala   |   7 +-
 .../spark/storage/BlockFetcherIterator.scala    |   9 +-
 .../spark/storage/BlockManagerWorker.scala      |  30 ++--
 .../spark/network/ConnectionManagerSuite.scala  |  38 ++++-
 .../storage/BlockFetcherIteratorSuite.scala     |  98 ++++++++++++-
 .../spark/storage/BlockManagerSuite.scala       | 110 +++++++++++++-
 10 files changed, 362 insertions(+), 89 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/17caae48/core/src/main/scala/org/apache/spark/network/BufferMessage.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala
index 04df2f3..af35f1f 100644
--- a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala
+++ b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala
@@ -48,7 +48,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var
ackId:
     val security = if (isSecurityNeg) 1 else 0
     if (size == 0 && !gotChunkForSendingOnce) {
       val newChunk = new MessageChunk(
-        new MessageChunkHeader(typ, id, 0, 0, ackId, security, senderAddress), null)
+        new MessageChunkHeader(typ, id, 0, 0, ackId, hasError, security, senderAddress),
null)
       gotChunkForSendingOnce = true
       return Some(newChunk)
     }
@@ -66,7 +66,8 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var
ackId:
         }
         buffer.position(buffer.position + newBuffer.remaining)
         val newChunk = new MessageChunk(new MessageChunkHeader(
-            typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer)
+          typ, id, size, newBuffer.remaining, ackId,
+          hasError, security, senderAddress), newBuffer)
         gotChunkForSendingOnce = true
         return Some(newChunk)
       }
@@ -88,7 +89,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var
ackId:
       val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer]
       buffer.position(buffer.position + newBuffer.remaining)
       val newChunk = new MessageChunk(new MessageChunkHeader(
-          typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer)
+          typ, id, size, newBuffer.remaining, ackId, hasError, security, senderAddress),
newBuffer)
       return Some(newChunk)
     }
     None

http://git-wip-us.apache.org/repos/asf/spark/blob/17caae48/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
index 4c00225..95f96b8 100644
--- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.network
 
+import java.io.IOException
 import java.nio._
 import java.nio.channels._
 import java.nio.channels.spi._
@@ -45,16 +46,26 @@ private[spark] class ConnectionManager(
     name: String = "Connection manager")
   extends Logging {
 
+  /**
+   * Used by sendMessageReliably to track messages being sent.
+   * @param message the message that was sent
+   * @param connectionManagerId the connection manager that sent this message
+   * @param completionHandler callback that's invoked when the send has completed or failed
+   */
   class MessageStatus(
       val message: Message,
       val connectionManagerId: ConnectionManagerId,
       completionHandler: MessageStatus => Unit) {
 
+    /** This is non-None if message has been ack'd */
     var ackMessage: Option[Message] = None
-    var attempted = false
-    var acked = false
 
-    def markDone() { completionHandler(this) }
+    def markDone(ackMessage: Option[Message]) {
+      this.synchronized {
+        this.ackMessage = ackMessage
+        completionHandler(this)
+      }
+    }
   }
 
   private val selector = SelectorProvider.provider.openSelector()
@@ -442,11 +453,7 @@ private[spark] class ConnectionManager(
             messageStatuses.values.filter(_.connectionManagerId == sendingConnectionManagerId)
               .foreach(status => {
                 logInfo("Notifying " + status)
-                status.synchronized {
-                  status.attempted = true
-                  status.acked = false
-                  status.markDone()
-                }
+                status.markDone(None)
               })
 
             messageStatuses.retain((i, status) => {
@@ -475,11 +482,7 @@ private[spark] class ConnectionManager(
             for (s <- messageStatuses.values
                  if s.connectionManagerId == sendingConnectionManagerId) {
               logInfo("Notifying " + s)
-              s.synchronized {
-                s.attempted = true
-                s.acked = false
-                s.markDone()
-              }
+              s.markDone(None)
             }
 
             messageStatuses.retain((i, status) => {
@@ -547,13 +550,13 @@ private[spark] class ConnectionManager(
         val securityMsgResp = SecurityMessage.fromResponse(replyToken,
           securityMsg.getConnectionId.toString)
         val message = securityMsgResp.toBufferMessage
-        if (message == null) throw new Exception("Error creating security message")
+        if (message == null) throw new IOException("Error creating security message")
         sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message)
       } catch  {
         case e: Exception => {
           logError("Error handling sasl client authentication", e)
           waitingConn.close()
-          throw new Exception("Error evaluating sasl response: " + e)
+          throw new IOException("Error evaluating sasl response: ", e)
         }
       }
     }
@@ -661,34 +664,39 @@ private[spark] class ConnectionManager(
               }
             }
           }
-          sentMessageStatus.synchronized {
-            sentMessageStatus.ackMessage = Some(message)
-            sentMessageStatus.attempted = true
-            sentMessageStatus.acked = true
-            sentMessageStatus.markDone()
-          }
+          sentMessageStatus.markDone(Some(message))
         } else {
-          val ackMessage = if (onReceiveCallback != null) {
-            logDebug("Calling back")
-            onReceiveCallback(bufferMessage, connectionManagerId)
-          } else {
-            logDebug("Not calling back as callback is null")
-            None
-          }
+          var ackMessage : Option[Message] = None
+          try {
+            ackMessage = if (onReceiveCallback != null) {
+              logDebug("Calling back")
+              onReceiveCallback(bufferMessage, connectionManagerId)
+            } else {
+              logDebug("Not calling back as callback is null")
+              None
+            }
 
-          if (ackMessage.isDefined) {
-            if (!ackMessage.get.isInstanceOf[BufferMessage]) {
-              logDebug("Response to " + bufferMessage + " is not a buffer message, it is
of type "
-                + ackMessage.get.getClass)
-            } else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) {
-              logDebug("Response to " + bufferMessage + " does not have ack id set")
-              ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id
+            if (ackMessage.isDefined) {
+              if (!ackMessage.get.isInstanceOf[BufferMessage]) {
+                logDebug("Response to " + bufferMessage + " is not a buffer message, it is
of type "
+                  + ackMessage.get.getClass)
+              } else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) {
+                logDebug("Response to " + bufferMessage + " does not have ack id set")
+                ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id
+              }
+            }
+          } catch {
+            case e: Exception => {
+              logError(s"Exception was thrown while processing message", e)
+              val m = Message.createBufferMessage(bufferMessage.id)
+              m.hasError = true
+              ackMessage = Some(m)
             }
+          } finally {
+            sendMessage(connectionManagerId, ackMessage.getOrElse {
+              Message.createBufferMessage(bufferMessage.id)
+            })
           }
-
-          sendMessage(connectionManagerId, ackMessage.getOrElse {
-            Message.createBufferMessage(bufferMessage.id)
-          })
         }
       }
       case _ => throw new Exception("Unknown type message received")
@@ -800,11 +808,7 @@ private[spark] class ConnectionManager(
             case Some(msgStatus) => {
               messageStatuses -= message.id
               logInfo("Notifying " + msgStatus.connectionManagerId)
-              msgStatus.synchronized {
-                msgStatus.attempted = true
-                msgStatus.acked = false
-                msgStatus.markDone()
-              }
+              msgStatus.markDone(None)
             }
             case None => {
               logError("no messageStatus for failed message id: " + message.id)
@@ -823,11 +827,28 @@ private[spark] class ConnectionManager(
     selector.wakeup()
   }
 
+  /**
+   * Send a message and block until an acknowledgment is received or an error occurs.
+   * @param connectionManagerId the message's destination
+   * @param message the message being sent
+   * @return a Future that either returns the acknowledgment message or captures an exception.
+   */
   def sendMessageReliably(connectionManagerId: ConnectionManagerId, message: Message)
-      : Future[Option[Message]] = {
-    val promise = Promise[Option[Message]]
-    val status = new MessageStatus(
-      message, connectionManagerId, s => promise.success(s.ackMessage))
+      : Future[Message] = {
+    val promise = Promise[Message]()
+    val status = new MessageStatus(message, connectionManagerId, s => {
+      s.ackMessage match {
+        case None =>  // Indicates a failure where we either never sent or never got ACK'd
+          promise.failure(new IOException("sendMessageReliably failed without being ACK'd"))
+        case Some(ackMessage) =>
+          if (ackMessage.hasError) {
+            promise.failure(
+              new IOException("sendMessageReliably failed with ACK that signalled a remote
error"))
+          } else {
+            promise.success(ackMessage)
+          }
+      }
+    })
     messageStatuses.synchronized {
       messageStatuses += ((message.id, status))
     }
@@ -835,11 +856,6 @@ private[spark] class ConnectionManager(
     promise.future
   }
 
-  def sendMessageReliablySync(connectionManagerId: ConnectionManagerId,
-      message: Message): Option[Message] = {
-    Await.result(sendMessageReliably(connectionManagerId, message), Duration.Inf)
-  }
-
   def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]) {
     onReceiveCallback = callback
   }
@@ -862,6 +878,7 @@ private[spark] class ConnectionManager(
 
 
 private[spark] object ConnectionManager {
+  import ExecutionContext.Implicits.global
 
   def main(args: Array[String]) {
     val conf = new SparkConf
@@ -896,7 +913,7 @@ private[spark] object ConnectionManager {
 
     (0 until count).map(i => {
       val bufferMessage = Message.createBufferMessage(buffer.duplicate)
-      manager.sendMessageReliablySync(manager.id, bufferMessage)
+      Await.result(manager.sendMessageReliably(manager.id, bufferMessage), Duration.Inf)
     })
     println("--------------------------")
     println()
@@ -917,8 +934,10 @@ private[spark] object ConnectionManager {
       val bufferMessage = Message.createBufferMessage(buffer.duplicate)
       manager.sendMessageReliably(manager.id, bufferMessage)
     }).foreach(f => {
-      val g = Await.result(f, 1 second)
-      if (!g.isDefined) println("Failed")
+      f.onFailure {
+        case e => println("Failed due to " + e)
+      }
+      Await.ready(f, 1 second)
     })
     val finishTime = System.currentTimeMillis
 
@@ -952,8 +971,10 @@ private[spark] object ConnectionManager {
       val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate)
       manager.sendMessageReliably(manager.id, bufferMessage)
     }).foreach(f => {
-      val g = Await.result(f, 1 second)
-      if (!g.isDefined) println("Failed")
+      f.onFailure {
+        case e => println("Failed due to " + e)
+      }
+      Await.ready(f, 1 second)
     })
     val finishTime = System.currentTimeMillis
 
@@ -982,8 +1003,10 @@ private[spark] object ConnectionManager {
           val bufferMessage = Message.createBufferMessage(buffer.duplicate)
           manager.sendMessageReliably(manager.id, bufferMessage)
         }).foreach(f => {
-          val g = Await.result(f, 1 second)
-          if (!g.isDefined) println("Failed")
+          f.onFailure {
+            case e => println("Failed due to " + e)
+          }
+          Await.ready(f, 1 second)
         })
       val finishTime = System.currentTimeMillis
       Thread.sleep(1000)

http://git-wip-us.apache.org/repos/asf/spark/blob/17caae48/core/src/main/scala/org/apache/spark/network/Message.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/Message.scala b/core/src/main/scala/org/apache/spark/network/Message.scala
index 7caccfd..04ea50f 100644
--- a/core/src/main/scala/org/apache/spark/network/Message.scala
+++ b/core/src/main/scala/org/apache/spark/network/Message.scala
@@ -28,6 +28,7 @@ private[spark] abstract class Message(val typ: Long, val id: Int) {
   var startTime = -1L
   var finishTime = -1L
   var isSecurityNeg = false
+  var hasError = false
 
   def size: Int
 
@@ -87,6 +88,7 @@ private[spark] object Message {
       case BUFFER_MESSAGE => new BufferMessage(header.id,
         ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other)
     }
+    newMessage.hasError = header.hasError
     newMessage.senderAddress = header.address
     newMessage
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/17caae48/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala
index ead663e..f3ecca5 100644
--- a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala
+++ b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala
@@ -27,6 +27,7 @@ private[spark] class MessageChunkHeader(
     val totalSize: Int,
     val chunkSize: Int,
     val other: Int,
+    val hasError: Boolean,
     val securityNeg: Int,
     val address: InetSocketAddress) {
   lazy val buffer = {
@@ -41,6 +42,7 @@ private[spark] class MessageChunkHeader(
       putInt(totalSize).
       putInt(chunkSize).
       putInt(other).
+      put(if (hasError) 1.asInstanceOf[Byte] else 0.asInstanceOf[Byte]).
       putInt(securityNeg).
       putInt(ip.size).
       put(ip).
@@ -56,7 +58,7 @@ private[spark] class MessageChunkHeader(
 
 
 private[spark] object MessageChunkHeader {
-  val HEADER_SIZE = 44
+  val HEADER_SIZE = 45
 
   def create(buffer: ByteBuffer): MessageChunkHeader = {
     if (buffer.remaining != HEADER_SIZE) {
@@ -67,13 +69,14 @@ private[spark] object MessageChunkHeader {
     val totalSize = buffer.getInt()
     val chunkSize = buffer.getInt()
     val other = buffer.getInt()
+    val hasError = buffer.get() != 0
     val securityNeg = buffer.getInt()
     val ipSize = buffer.getInt()
     val ipBytes = new Array[Byte](ipSize)
     buffer.get(ipBytes)
     val ip = InetAddress.getByAddress(ipBytes)
     val port = buffer.getInt()
-    new MessageChunkHeader(typ, id, totalSize, chunkSize, other, securityNeg,
+    new MessageChunkHeader(typ, id, totalSize, chunkSize, other, hasError, securityNeg,
       new InetSocketAddress(ip, port))
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/17caae48/core/src/main/scala/org/apache/spark/network/SenderTest.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/SenderTest.scala b/core/src/main/scala/org/apache/spark/network/SenderTest.scala
index b8ea7c2..ea2ad10 100644
--- a/core/src/main/scala/org/apache/spark/network/SenderTest.scala
+++ b/core/src/main/scala/org/apache/spark/network/SenderTest.scala
@@ -20,6 +20,10 @@ package org.apache.spark.network
 import java.nio.ByteBuffer
 import org.apache.spark.{SecurityManager, SparkConf}
 
+import scala.concurrent.Await
+import scala.concurrent.duration.Duration
+import scala.util.Try
+
 private[spark] object SenderTest {
   def main(args: Array[String]) {
 
@@ -51,7 +55,8 @@ private[spark] object SenderTest {
       val dataMessage = Message.createBufferMessage(buffer.duplicate)
       val startTime = System.currentTimeMillis
       /* println("Started timer at " + startTime) */
-      val responseStr = manager.sendMessageReliablySync(targetConnectionManagerId, dataMessage)
+      val promise = manager.sendMessageReliably(targetConnectionManagerId, dataMessage)
+      val responseStr: String = Try(Await.result(promise, Duration.Inf))
         .map { response =>
           val buffer = response.asInstanceOf[BufferMessage].buffers(0)
           new String(buffer.array, "utf-8")

http://git-wip-us.apache.org/repos/asf/spark/blob/17caae48/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
index ccf830e..938af6f 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
@@ -22,6 +22,7 @@ import java.util.concurrent.LinkedBlockingQueue
 import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.HashSet
 import scala.collection.mutable.Queue
+import scala.util.{Failure, Success}
 
 import io.netty.buffer.ByteBuf
 
@@ -118,8 +119,8 @@ object BlockFetcherIterator {
       bytesInFlight += req.size
       val sizeMap = req.blocks.toMap  // so we can look up the size of each blockID
       val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
-      future.onSuccess {
-        case Some(message) => {
+      future.onComplete {
+        case Success(message) => {
           val bufferMessage = message.asInstanceOf[BufferMessage]
           val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
           for (blockMessage <- blockMessageArray) {
@@ -135,8 +136,8 @@ object BlockFetcherIterator {
             logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
           }
         }
-        case None => {
-          logError("Could not get block(s) from " + cmId)
+        case Failure(exception) => {
+          logError("Could not get block(s) from " + cmId, exception)
           for ((blockId, size) <- req.blocks) {
             results.put(new FetchResult(blockId, -1, null))
           }

http://git-wip-us.apache.org/repos/asf/spark/blob/17caae48/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala
index c7766a3..bf002a4 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala
@@ -23,6 +23,10 @@ import org.apache.spark.Logging
 import org.apache.spark.network._
 import org.apache.spark.util.Utils
 
+import scala.concurrent.Await
+import scala.concurrent.duration.Duration
+import scala.util.{Try, Failure, Success}
+
 /**
  * A network interface for BlockManager. Each slave should have one
  * BlockManagerWorker.
@@ -44,13 +48,19 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager)
extends
           val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get)
           Some(new BlockMessageArray(responseMessages).toBufferMessage)
         } catch {
-          case e: Exception => logError("Exception handling buffer message", e)
-          None
+          case e: Exception => {
+            logError("Exception handling buffer message", e)
+            val errorMessage = Message.createBufferMessage(msg.id)
+            errorMessage.hasError = true
+            Some(errorMessage)
+          }
         }
       }
       case otherMessage: Any => {
         logError("Unknown type message received: " + otherMessage)
-        None
+        val errorMessage = Message.createBufferMessage(msg.id)
+        errorMessage.hasError = true
+        Some(errorMessage)
       }
     }
   }
@@ -109,9 +119,9 @@ private[spark] object BlockManagerWorker extends Logging {
     val connectionManager = blockManager.connectionManager
     val blockMessage = BlockMessage.fromPutBlock(msg)
     val blockMessageArray = new BlockMessageArray(blockMessage)
-    val resultMessage = connectionManager.sendMessageReliablySync(
-        toConnManagerId, blockMessageArray.toBufferMessage)
-    resultMessage.isDefined
+    val resultMessage = Try(Await.result(connectionManager.sendMessageReliably(
+        toConnManagerId, blockMessageArray.toBufferMessage), Duration.Inf))
+    resultMessage.isSuccess
   }
 
   def syncGetBlock(msg: GetBlock, toConnManagerId: ConnectionManagerId): ByteBuffer = {
@@ -119,10 +129,10 @@ private[spark] object BlockManagerWorker extends Logging {
     val connectionManager = blockManager.connectionManager
     val blockMessage = BlockMessage.fromGetBlock(msg)
     val blockMessageArray = new BlockMessageArray(blockMessage)
-    val responseMessage = connectionManager.sendMessageReliablySync(
-        toConnManagerId, blockMessageArray.toBufferMessage)
+    val responseMessage = Try(Await.result(connectionManager.sendMessageReliably(
+        toConnManagerId, blockMessageArray.toBufferMessage), Duration.Inf))
     responseMessage match {
-      case Some(message) => {
+      case Success(message) => {
         val bufferMessage = message.asInstanceOf[BufferMessage]
         logDebug("Response message received " + bufferMessage)
         BlockMessageArray.fromBufferMessage(bufferMessage).foreach(blockMessage => {
@@ -130,7 +140,7 @@ private[spark] object BlockManagerWorker extends Logging {
             return blockMessage.getData
           })
       }
-      case None => logDebug("No response message received")
+      case Failure(exception) => logDebug("No response message received")
     }
     null
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/17caae48/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala
index 415ad8c..846537d 100644
--- a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.network
 
+import java.io.IOException
 import java.nio._
 
 import org.apache.spark.{SecurityManager, SparkConf}
@@ -25,6 +26,7 @@ import org.scalatest.FunSuite
 import scala.concurrent.{Await, TimeoutException}
 import scala.concurrent.duration._
 import scala.language.postfixOps
+import scala.util.Try
 
 /**
   * Test the ConnectionManager with various security settings.
@@ -46,7 +48,7 @@ class ConnectionManagerSuite extends FunSuite {
     buffer.flip
 
     val bufferMessage = Message.createBufferMessage(buffer.duplicate)
-    manager.sendMessageReliablySync(manager.id, bufferMessage)
+    Await.result(manager.sendMessageReliably(manager.id, bufferMessage), 10 seconds)
 
     assert(receivedMessage == true)
 
@@ -79,7 +81,7 @@ class ConnectionManagerSuite extends FunSuite {
 
     (0 until count).map(i => {
       val bufferMessage = Message.createBufferMessage(buffer.duplicate)
-      manager.sendMessageReliablySync(managerServer.id, bufferMessage)
+      Await.result(manager.sendMessageReliably(managerServer.id, bufferMessage), 10 seconds)
     })
 
     assert(numReceivedServerMessages == 10)
@@ -118,7 +120,10 @@ class ConnectionManagerSuite extends FunSuite {
     val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
     buffer.flip
     val bufferMessage = Message.createBufferMessage(buffer.duplicate)
-    manager.sendMessageReliablySync(managerServer.id, bufferMessage)
+    // Expect managerServer to close connection, which we'll report as an error:
+    intercept[IOException] {
+      Await.result(manager.sendMessageReliably(managerServer.id, bufferMessage), 10 seconds)
+    }
 
     assert(numReceivedServerMessages == 0)
     assert(numReceivedMessages == 0)
@@ -163,6 +168,8 @@ class ConnectionManagerSuite extends FunSuite {
         val g = Await.result(f, 1 second)
         assert(false)
       } catch {
+        case i: IOException =>
+          assert(true)
         case e: TimeoutException => {
           // we should timeout here since the client can't do the negotiation
           assert(true)
@@ -209,7 +216,6 @@ class ConnectionManagerSuite extends FunSuite {
     }).foreach(f => {
       try {
         val g = Await.result(f, 1 second)
-        if (!g.isDefined) assert(false) else assert(true)
       } catch {
         case e: Exception => {
           assert(false)
@@ -223,7 +229,31 @@ class ConnectionManagerSuite extends FunSuite {
     managerServer.stop()
   }
 
+  test("Ack error message") {
+    val conf = new SparkConf
+    conf.set("spark.authenticate", "false")
+    val securityManager = new SecurityManager(conf)
+    val manager = new ConnectionManager(0, conf, securityManager)
+    val managerServer = new ConnectionManager(0, conf, securityManager)
+    managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+      throw new Exception
+    })
+
+    val size = 10 * 1024 * 1024
+    val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+    buffer.flip
+    val bufferMessage = Message.createBufferMessage(buffer)
+
+    val future = manager.sendMessageReliably(managerServer.id, bufferMessage)
+
+    intercept[IOException] {
+      Await.result(future, 1 second)
+    }
 
+    manager.stop()
+    managerServer.stop()
+
+  }
 
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/17caae48/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala
b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala
index 8dca2eb..1538995 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala
@@ -17,18 +17,22 @@
 
 package org.apache.spark.storage
 
+import java.io.IOException
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.future
+import scala.concurrent.ExecutionContext.Implicits.global
+
 import org.scalatest.{FunSuite, Matchers}
-import org.scalatest.PrivateMethodTester._
 
 import org.mockito.Mockito._
 import org.mockito.Matchers.{any, eq => meq}
 import org.mockito.stubbing.Answer
 import org.mockito.invocation.InvocationOnMock
 
-import org.apache.spark._
 import org.apache.spark.storage.BlockFetcherIterator._
-import org.apache.spark.network.{ConnectionManager, ConnectionManagerId,
-                                 Message}
+import org.apache.spark.network.{ConnectionManager, Message}
 
 class BlockFetcherIteratorSuite extends FunSuite with Matchers {
 
@@ -137,4 +141,90 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers {
     assert(iterator.next._2.isDefined, "All elements should be defined but 5th element is not actually defined")
   }
 
+  // Regression test for SPARK-2583: a failed remote fetch must complete the
+  // iterator with undefined results instead of hanging the caller.
+  test("block fetch from remote fails using BasicBlockFetcherIterator") {
+    val blockManager = mock(classOf[BlockManager])
+    val connManager = mock(classOf[ConnectionManager])
+    when(blockManager.connectionManager).thenReturn(connManager)
+
+    // Every send fails with the IOException used to signal an error ACK,
+    // so no block can ever be fetched successfully.
+    val f = future {
+      throw new IOException("Send failed or we received an error ACK")
+    }
+    when(connManager.sendMessageReliably(any(),
+      any())).thenReturn(f)
+    when(blockManager.futureExecContext).thenReturn(global)
+
+    when(blockManager.blockManagerId).thenReturn(
+      BlockManagerId("test-client", "test-client", 1, 0))
+    when(blockManager.maxBytesInFlight).thenReturn(48 * 1024 * 1024)
+
+    // Request two shuffle blocks from a single (mocked) remote block manager.
+    val blId1 = ShuffleBlockId(0,0,0)
+    val blId2 = ShuffleBlockId(0,1,0)
+    val bmId = BlockManagerId("test-server", "test-server",1 , 0)
+    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+      (bmId, Seq((blId1, 1L), (blId2, 1L)))
+    )
+
+    val iterator = new BasicBlockFetcherIterator(blockManager,
+      blocksByAddress, null)
+
+    iterator.initialize()
+    // The iterator must still yield one entry per requested block, each with
+    // an empty result, rather than blocking on the failed future.
+    iterator.foreach{
+      case (_, r) => {
+        (!r.isDefined) should be(true)
+      }
+    }
+  }
+
+  // Happy-path counterpart to the failure test above: when the remote side
+  // answers with a valid GotBlock message, every result should be defined.
+  test("block fetch from remote succeed using BasicBlockFetcherIterator") {
+    val blockManager = mock(classOf[BlockManager])
+    val connManager = mock(classOf[ConnectionManager])
+    when(blockManager.connectionManager).thenReturn(connManager)
+
+    // Build two GotBlock replies, each carrying a 4-byte payload.
+    val blId1 = ShuffleBlockId(0,0,0)
+    val blId2 = ShuffleBlockId(0,1,0)
+    val buf1 = ByteBuffer.allocate(4)
+    val buf2 = ByteBuffer.allocate(4)
+    buf1.putInt(1)
+    buf1.flip()
+    buf2.putInt(1)
+    buf2.flip()
+    val blockMessage1 = BlockMessage.fromGotBlock(GotBlock(blId1, buf1))
+    val blockMessage2 = BlockMessage.fromGotBlock(GotBlock(blId2, buf2))
+    val blockMessageArray = new BlockMessageArray(
+      Seq(blockMessage1, blockMessage2))
+
+    // Flatten the reply into a single contiguous buffer, the form in which a
+    // real ack would arrive from the ConnectionManager.
+    val bufferMessage = blockMessageArray.toBufferMessage
+    val buffer = ByteBuffer.allocate(bufferMessage.size)
+    val arrayBuffer = new ArrayBuffer[ByteBuffer]
+    bufferMessage.buffers.foreach{ b =>
+      buffer.put(b)
+    }
+    buffer.flip()
+    arrayBuffer += buffer
+
+    // Every send succeeds and returns the canned reply built above.
+    val f = future {
+      Message.createBufferMessage(arrayBuffer)
+    }
+    when(connManager.sendMessageReliably(any(),
+      any())).thenReturn(f)
+    when(blockManager.futureExecContext).thenReturn(global)
+
+    when(blockManager.blockManagerId).thenReturn(
+      BlockManagerId("test-client", "test-client", 1, 0))
+    when(blockManager.maxBytesInFlight).thenReturn(48 * 1024 * 1024)
+
+    val bmId = BlockManagerId("test-server", "test-server",1 , 0)
+    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+      (bmId, Seq((blId1, 1L), (blId2, 1L)))
+    )
+
+    val iterator = new BasicBlockFetcherIterator(blockManager,
+      blocksByAddress, null)
+    iterator.initialize()
+    // Both requested blocks were answered, so each result must be defined.
+    iterator.foreach{
+      case (_, r) => {
+        (r.isDefined) should be(true)
+      }
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/17caae48/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 0ac0269..94bb2c4 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -25,7 +25,11 @@ import akka.actor._
 import akka.pattern.ask
 import akka.util.Timeout
 
-import org.mockito.Mockito.{mock, when}
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.Matchers.any
+import org.mockito.Mockito.{doAnswer, mock, spy, when}
+import org.mockito.stubbing.Answer
+
 import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester}
 import org.scalatest.concurrent.Eventually._
 import org.scalatest.concurrent.Timeouts._
@@ -33,6 +37,7 @@ import org.scalatest.Matchers
 
 import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf}
 import org.apache.spark.executor.DataReadMethod
+import org.apache.spark.network.{Message, ConnectionManagerId}
 import org.apache.spark.scheduler.LiveListenerBus
 import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
 import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
@@ -1000,6 +1005,109 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
     assert(!store.memoryStore.contains(rdd(1, 0)), "rdd_1_0 was in store")
   }
 
+  // SPARK-2583: when message processing fails, the worker must reply with an
+  // error ack instead of nothing, so the sender does not hang.
+  test("return error message when error occurred in BlockManagerWorker#onBlockMessageReceive") {
+    store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
+      securityMgr, mapOutputTracker)
+
+    // Spy on a real worker so processBlockMessage can be stubbed while the
+    // rest of onBlockMessageReceive runs for real.
+    val worker = spy(new BlockManagerWorker(store))
+    val connManagerId = mock(classOf[ConnectionManagerId])
+
+    // setup request block messages
+    val reqBlId1 = ShuffleBlockId(0, 0, 0)
+    val reqBlId2 = ShuffleBlockId(0, 1, 0)
+    val reqBlockMessage1 = BlockMessage.fromGetBlock(GetBlock(reqBlId1))
+    val reqBlockMessage2 = BlockMessage.fromGetBlock(GetBlock(reqBlId2))
+    val reqBlockMessages = new BlockMessageArray(
+      Seq(reqBlockMessage1, reqBlockMessage2))
+    val reqBufferMessage = reqBlockMessages.toBufferMessage
+
+    // Stub processBlockMessage to always throw, simulating an uncaught error
+    // while handling a block message.
+    val answer = new Answer[Option[BlockMessage]] {
+      override def answer(invocation: InvocationOnMock): Option[BlockMessage] = {
+        throw new Exception
+      }
+    }
+
+    doAnswer(answer).when(worker).processBlockMessage(any())
+
+    // Test when exception was thrown during processing block messages
+    var ackMessage = worker.onBlockMessageReceive(reqBufferMessage, connManagerId)
+
+    assert(ackMessage.isDefined, "When Exception was thrown in " +
+      "BlockManagerWorker#processBlockMessage, " +
+      "ackMessage should be defined")
+    assert(ackMessage.get.hasError, "When Exception was thrown in " +
+      "BlockManagerWorker#processBlockMessage, " +
+      "ackMessage should have error")
+
+    val notBufferMessage = mock(classOf[Message])
+
+    // Test when not BufferMessage was received
+    ackMessage = worker.onBlockMessageReceive(notBufferMessage, connManagerId)
+    assert(ackMessage.isDefined, "When not BufferMessage was passed to " +
+      "BlockManagerWorker#onBlockMessageReceive, " +
+      "ackMessage should be defined")
+    assert(ackMessage.get.hasError, "When not BufferMessage was passed to " +
+      "BlockManagerWorker#onBlockMessageReceive, " +
+      "ackMessage should have error")
+  }
+
+  // Happy-path counterpart: a successfully processed request must yield an
+  // ack without the error flag set.
+  test("return ack message when no error occurred in BlockManagerWorker#onBlockMessageReceive") {
+    store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
+      securityMgr, mapOutputTracker)
+
+    val worker = spy(new BlockManagerWorker(store))
+    val connManagerId = mock(classOf[ConnectionManagerId])
+
+    // setup request block messages
+    val reqBlId1 = ShuffleBlockId(0, 0, 0)
+    val reqBlId2 = ShuffleBlockId(0, 1, 0)
+    val reqBlockMessage1 = BlockMessage.fromGetBlock(GetBlock(reqBlId1))
+    val reqBlockMessage2 = BlockMessage.fromGetBlock(GetBlock(reqBlId2))
+    val reqBlockMessages = new BlockMessageArray(
+      Seq(reqBlockMessage1, reqBlockMessage2))
+
+    // Flatten the request into one contiguous buffer, as it would arrive
+    // off the wire.
+    val tmpBufferMessage = reqBlockMessages.toBufferMessage
+    val buffer = ByteBuffer.allocate(tmpBufferMessage.size)
+    val arrayBuffer = new ArrayBuffer[ByteBuffer]
+    tmpBufferMessage.buffers.foreach { b =>
+      buffer.put(b)
+    }
+    buffer.flip()
+    arrayBuffer += buffer
+    val reqBufferMessage = Message.createBufferMessage(arrayBuffer)
+
+    // setup ack block messages
+    val buf1 = ByteBuffer.allocate(4)
+    val buf2 = ByteBuffer.allocate(4)
+    buf1.putInt(1)
+    buf1.flip()
+    buf2.putInt(1)
+    buf2.flip()
+    val ackBlockMessage1 = BlockMessage.fromGotBlock(GotBlock(reqBlId1, buf1))
+    val ackBlockMessage2 = BlockMessage.fromGotBlock(GotBlock(reqBlId2, buf2))
+
+    // Answer each processed message with a canned ack. NOTE(review): eq is
+    // reference equality, so this only matches if the worker passes the
+    // original message instance through; otherwise the else branch answers.
+    // Preserved as in the original patch.
+    val answer = new Answer[Option[BlockMessage]] {
+      override def answer(invocation: InvocationOnMock): Option[BlockMessage] = {
+        if (invocation.getArguments()(0).asInstanceOf[BlockMessage].eq(reqBlockMessage1)) {
+          Some(ackBlockMessage1)
+        } else {
+          Some(ackBlockMessage2)
+        }
+      }
+    }
+
+    doAnswer(answer).when(worker).processBlockMessage(any())
+
+    val ackMessage = worker.onBlockMessageReceive(reqBufferMessage, connManagerId)
+    assert(ackMessage.isDefined, "When BlockManagerWorker#onBlockMessageReceive " +
+      "was executed successfully, ackMessage should be defined")
+    assert(!ackMessage.get.hasError, "When BlockManagerWorker#onBlockMessageReceive " +
+      "was executed successfully, ackMessage should not have error")
+  }
+
   test("reserve/release unroll memory") {
     store = makeBlockManager(12000)
     val memoryStore = store.memoryStore


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message