spark-commits mailing list archives

From: r...@apache.org
Subject: [1/5] [SPARK-2468] Netty based block server / client module
Date: Fri, 15 Aug 2014 02:11:35 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.1 72e730e98 -> 3f23d2a38


http://git-wip-us.apache.org/repos/asf/spark/blob/3f23d2a3/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala
new file mode 100644
index 0000000..ef3478a
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty
+
+import java.io.{RandomAccessFile, File}
+import java.nio.ByteBuffer
+import java.util.{Collections, HashSet}
+import java.util.concurrent.{TimeUnit, Semaphore}
+
+import scala.collection.JavaConversions._
+
+import io.netty.buffer.{ByteBufUtil, Unpooled}
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+import org.apache.spark.SparkConf
+import org.apache.spark.network.netty.client.{ReferenceCountedBuffer, BlockFetchingClientFactory}
+import org.apache.spark.network.netty.server.BlockServer
+import org.apache.spark.storage.{FileSegment, BlockDataProvider}
+
+
+/**
+ * Test suite that makes sure the server and the client implementations share the same protocol.
+ */
+class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll {
+
+  val bufSize = 100000
+  var buf: ByteBuffer = _
+  var testFile: File = _
+  var server: BlockServer = _
+  var clientFactory: BlockFetchingClientFactory = _
+
+  val bufferBlockId = "buffer_block"
+  val fileBlockId = "file_block"
+
+  val fileContent = new Array[Byte](1024)
+  scala.util.Random.nextBytes(fileContent)
+
+  override def beforeAll() = {
+    buf = ByteBuffer.allocate(bufSize)
+    for (i <- 1 to bufSize) {
+      buf.put(i.toByte)
+    }
+    buf.flip()
+
+    testFile = File.createTempFile("netty-test-file", "txt")
+    val fp = new RandomAccessFile(testFile, "rw")
+    fp.write(fileContent)
+    fp.close()
+
+    server = new BlockServer(new SparkConf, new BlockDataProvider {
+      override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = {
+        if (blockId == bufferBlockId) {
+          Right(buf)
+        } else if (blockId == fileBlockId) {
+          Left(new FileSegment(testFile, 10, testFile.length - 25))
+        } else {
+          throw new Exception("Unknown block id " + blockId)
+        }
+      }
+    })
+
+    clientFactory = new BlockFetchingClientFactory(new SparkConf)
+  }
+
+  override def afterAll() = {
+    server.stop()
+    clientFactory.stop()
+  }
+
+  /** A ByteBuf for buffer_block */
+  lazy val byteBufferBlockReference = Unpooled.wrappedBuffer(buf)
+
+  /** A ByteBuf for file_block */
+  lazy val fileBlockReference = Unpooled.wrappedBuffer(fileContent, 10, fileContent.length - 25)
+
+  def fetchBlocks(blockIds: Seq[String]): (Set[String], Set[ReferenceCountedBuffer], Set[String]) =
+  {
+    val client = clientFactory.createClient(server.hostName, server.port)
+    val sem = new Semaphore(0)
+    val receivedBlockIds = Collections.synchronizedSet(new HashSet[String])
+    val errorBlockIds = Collections.synchronizedSet(new HashSet[String])
+    val receivedBuffers = Collections.synchronizedSet(new HashSet[ReferenceCountedBuffer])
+
+    client.fetchBlocks(
+      blockIds,
+      (blockId, buf) => {
+        receivedBlockIds.add(blockId)
+        buf.retain()
+        receivedBuffers.add(buf)
+        sem.release()
+      },
+      (blockId, errorMsg) => {
+        errorBlockIds.add(blockId)
+        sem.release()
+      }
+    )
+    if (!sem.tryAcquire(blockIds.size, 30, TimeUnit.SECONDS)) {
+      fail("Timeout getting response from the server")
+    }
+    client.close()
+    (receivedBlockIds.toSet, receivedBuffers.toSet, errorBlockIds.toSet)
+  }
+
+  test("fetch a ByteBuffer block") {
+    val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId))
+    assert(blockIds === Set(bufferBlockId))
+    assert(buffers.map(_.underlying) === Set(byteBufferBlockReference))
+    assert(failBlockIds.isEmpty)
+    buffers.foreach(_.release())
+  }
+
+  test("fetch a FileSegment block via zero-copy send") {
+    val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(fileBlockId))
+    assert(blockIds === Set(fileBlockId))
+    assert(buffers.map(_.underlying) === Set(fileBlockReference))
+    assert(failBlockIds.isEmpty)
+    buffers.foreach(_.release())
+  }
+
+  test("fetch a non-existent block") {
+    val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq("random-block"))
+    assert(blockIds.isEmpty)
+    assert(buffers.isEmpty)
+    assert(failBlockIds === Set("random-block"))
+  }
+
+  test("fetch both ByteBuffer block and FileSegment block") {
+    val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId, fileBlockId))
+    assert(blockIds === Set(bufferBlockId, fileBlockId))
+    assert(buffers.map(_.underlying) === Set(byteBufferBlockReference, fileBlockReference))
+    assert(failBlockIds.isEmpty)
+    buffers.foreach(_.release())
+  }
+
+  test("fetch both ByteBuffer block and a non-existent block") {
+    val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId, "random-block"))
+    assert(blockIds === Set(bufferBlockId))
+    assert(buffers.map(_.underlying) === Set(byteBufferBlockReference))
+    assert(failBlockIds === Set("random-block"))
+    buffers.foreach(_.release())
+  }
+}
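
For reference, the client-side API this suite drives looks roughly like the
following when used outside a test. This is a condensed sketch based only on
the calls visible above (BlockFetchingClientFactory, createClient, fetchBlocks,
and the per-block success/failure callbacks); the host, port, and block id are
placeholders, not values from the patch.

    import org.apache.spark.SparkConf
    import org.apache.spark.network.netty.client.BlockFetchingClientFactory

    val factory = new BlockFetchingClientFactory(new SparkConf)
    val client = factory.createClient("shuffle-server-host", 12345)  // placeholder endpoint
    client.fetchBlocks(
      Seq("shuffle_0_1_2"),                      // hypothetical block id
      (blockId, buf) => {
        buf.retain()                             // hold the buffer beyond the callback
        // ... consume buf.byteBuffer(), then:
        buf.release()
      },
      (blockId, errorMsg) => println(s"failed to fetch $blockId: $errorMsg")
    )
    client.close()
    factory.stop()

Note the retain/release discipline: the buffers are reference-counted and the
callbacks fire asynchronously, which is why the suite retains each buffer
before stashing it, waits on a semaphore, and releases the buffers at the end
of every test.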

http://git-wip-us.apache.org/repos/asf/spark/blob/3f23d2a3/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala
new file mode 100644
index 0000000..9afdad6
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty.client
+
+import java.nio.ByteBuffer
+
+import io.netty.buffer.Unpooled
+import io.netty.channel.embedded.EmbeddedChannel
+
+import org.scalatest.FunSuite
+
+
+class BlockFetchingClientHandlerSuite extends FunSuite {
+
+  test("handling block data (successful fetch)") {
+    val blockId = "test_block"
+    val blockData = "blahblahblahblahblah"
+    val totalLength = 4 + blockId.length + blockData.length
+
+    var parsedBlockId: String = ""
+    var parsedBlockData: String = ""
+    val handler = new BlockFetchingClientHandler
+    handler.blockFetchSuccessCallback = (bid, refCntBuf) => {
+      parsedBlockId = bid
+      val bytes = new Array[Byte](refCntBuf.byteBuffer().remaining)
+      refCntBuf.byteBuffer().get(bytes)
+      parsedBlockData = new String(bytes)
+    }
+
+    val channel = new EmbeddedChannel(handler)
+    val buf = ByteBuffer.allocate(totalLength + 4)  // 4 bytes for the length field itself
+    buf.putInt(totalLength)
+    buf.putInt(blockId.length)
+    buf.put(blockId.getBytes)
+    buf.put(blockData.getBytes)
+    buf.flip()
+
+    channel.writeInbound(Unpooled.wrappedBuffer(buf))
+    assert(parsedBlockId === blockId)
+    assert(parsedBlockData === blockData)
+
+    channel.close()
+  }
+
+  test("handling error message (failed fetch)") {
+    val blockId = "test_block"
+    val errorMsg = "error erro5r error err4or error3 error6 error erro1r"
+    val totalLength = 4 + blockId.length + errorMsg.length
+
+    var parsedBlockId: String = ""
+    var parsedErrorMsg: String = ""
+    val handler = new BlockFetchingClientHandler
+    handler.blockFetchFailureCallback = (bid, msg) => {
+      parsedBlockId = bid
+      parsedErrorMsg = msg
+    }
+
+    val channel = new EmbeddedChannel(handler)
+    val buf = ByteBuffer.allocate(totalLength + 4)  // 4 bytes for the length field itself
+    buf.putInt(totalLength)
+    buf.putInt(-blockId.length)
+    buf.put(blockId.getBytes)
+    buf.put(errorMsg.getBytes)
+    buf.flip()
+
+    channel.writeInbound(Unpooled.wrappedBuffer(buf))
+    assert(parsedBlockId === blockId)
+    assert(parsedErrorMsg === errorMsg)
+
+    channel.close()
+  }
+}
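
The two tests above pin down the client's wire format for a fetch response: a
frame is [frame length][block id length][block id bytes][payload bytes], where
the frame length counts everything after itself and a negative block id length
marks the payload as an error message rather than block data. A minimal
decoding sketch of that layout (the helper name is illustrative, not part of
the patch):

    import java.nio.ByteBuffer

    // Decode one response frame as constructed by the tests above.
    def decodeFrame(buf: ByteBuffer): (String, Boolean, Array[Byte]) = {
      val frameLength = buf.getInt()        // bytes remaining in this frame
      val blockIdLength = buf.getInt()      // negative => error frame
      val isError = blockIdLength < 0
      val idBytes = new Array[Byte](math.abs(blockIdLength))
      buf.get(idBytes)
      val payload = new Array[Byte](frameLength - 4 - math.abs(blockIdLength))
      buf.get(payload)                      // block data, or the error message
      (new String(idBytes), isError, payload)
    }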

http://git-wip-us.apache.org/repos/asf/spark/blob/3f23d2a3/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala
new file mode 100644
index 0000000..3ee281c
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty.server
+
+import io.netty.buffer.ByteBuf
+import io.netty.channel.embedded.EmbeddedChannel
+
+import org.scalatest.FunSuite
+
+
+class BlockHeaderEncoderSuite extends FunSuite {
+
+  test("encode normal block data") {
+    val blockId = "test_block"
+    val channel = new EmbeddedChannel(new BlockHeaderEncoder)
+    channel.writeOutbound(new BlockHeader(17, blockId, None))
+    val out = channel.readOutbound().asInstanceOf[ByteBuf]
+    assert(out.readInt() === 4 + blockId.length + 17)
+    assert(out.readInt() === blockId.length)
+
+    val blockIdBytes = new Array[Byte](blockId.length)
+    out.readBytes(blockIdBytes)
+    assert(new String(blockIdBytes) === blockId)
+    assert(out.readableBytes() === 0)
+
+    channel.close()
+  }
+
+  test("encode error message") {
+    val blockId = "error_block"
+    val errorMsg = "error encountered"
+    val channel = new EmbeddedChannel(new BlockHeaderEncoder)
+    channel.writeOutbound(new BlockHeader(17, blockId, Some(errorMsg)))
+    val out = channel.readOutbound().asInstanceOf[ByteBuf]
+    assert(out.readInt() === 4 + blockId.length + errorMsg.length)
+    assert(out.readInt() === -blockId.length)
+
+    val blockIdBytes = new Array[Byte](blockId.length)
+    out.readBytes(blockIdBytes)
+    assert(new String(blockIdBytes) === blockId)
+
+    val errorMsgBytes = new Array[Byte](errorMsg.length)
+    out.readBytes(errorMsgBytes)
+    assert(new String(errorMsgBytes) === errorMsg)
+    assert(out.readableBytes() === 0)
+
+    channel.close()
+  }
+}
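
These assertions fix the server-side header layout: the length field counts
the (possibly negated) id-length int plus the id bytes, plus either the block
size (whose data the server handler writes separately) or the inline error
message. A sketch of an encoder that would satisfy the assertions above
(illustrative, not the patch's exact implementation):

    import io.netty.buffer.ByteBuf

    def encodeHeader(out: ByteBuf, blockSize: Int, blockId: String, error: Option[String]): Unit =
      error match {
        case None =>
          out.writeInt(4 + blockId.length + blockSize)
          out.writeInt(blockId.length)
          out.writeBytes(blockId.getBytes)
        case Some(msg) =>
          out.writeInt(4 + blockId.length + msg.length)
          out.writeInt(-blockId.length)     // negative id length marks an error
          out.writeBytes(blockId.getBytes)
          out.writeBytes(msg.getBytes)
      }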

http://git-wip-us.apache.org/repos/asf/spark/blob/3f23d2a3/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala
new file mode 100644
index 0000000..12f6d87
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty.server
+
+import java.io.File
+import java.nio.ByteBuffer
+
+import io.netty.buffer.{Unpooled, ByteBuf}
+import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler, DefaultFileRegion}
+import io.netty.channel.embedded.EmbeddedChannel
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.storage.{BlockDataProvider, FileSegment}
+
+
+class BlockServerHandlerSuite extends FunSuite {
+
+  test("ByteBuffer block") {
+    val expectedBlockId = "test_bytebuffer_block"
+    val buf = ByteBuffer.allocate(10000)
+    for (i <- 1 to 10000) {
+      buf.put(i.toByte)
+    }
+    buf.flip()
+
+    val channel = new EmbeddedChannel(new BlockServerHandler(new BlockDataProvider {
+      override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = Right(buf)
+    }))
+
+    channel.writeInbound(expectedBlockId)
+    assert(channel.outboundMessages().size === 2)
+
+    val out1 = channel.readOutbound().asInstanceOf[BlockHeader]
+    val out2 = channel.readOutbound().asInstanceOf[ByteBuf]
+
+    assert(out1.blockId === expectedBlockId)
+    assert(out1.blockSize === buf.remaining)
+    assert(out1.error === None)
+
+    assert(out2.equals(Unpooled.wrappedBuffer(buf)))
+
+    channel.close()
+  }
+
+  test("FileSegment block via zero-copy") {
+    val expectedBlockId = "test_file_block"
+    val url = Thread.currentThread.getContextClassLoader.getResource("netty-test-file.txt")
+    val testFile = new File(url.toURI)
+
+    val channel = new EmbeddedChannel(new BlockServerHandler(new BlockDataProvider {
+      override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = {
+        Left(new FileSegment(testFile, 15, testFile.length - 25))
+      }
+    }))
+
+    channel.writeInbound(expectedBlockId)
+    assert(channel.outboundMessages().size === 2)
+
+    val out1 = channel.readOutbound().asInstanceOf[BlockHeader]
+    val out2 = channel.readOutbound().asInstanceOf[DefaultFileRegion]
+
+    assert(out1.blockId === expectedBlockId)
+    assert(out1.blockSize === testFile.length - 25)
+    assert(out1.error === None)
+
+    assert(out2.count === testFile.length - 25)
+    assert(out2.position === 15)
+  }
+
+  test("pipeline exception propagation") {
+    val blockServerHandler = new BlockServerHandler(new BlockDataProvider {
+      override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = ???
+    })
+    val exceptionHandler = new SimpleChannelInboundHandler[String]() {
+      override def channelRead0(ctx: ChannelHandlerContext, msg: String): Unit = {
+        throw new Exception("this is an error")
+      }
+    }
+
+    val channel = new EmbeddedChannel(exceptionHandler, blockServerHandler)
+    assert(channel.isOpen)
+    channel.writeInbound("a message to trigger the error")
+    assert(!channel.isOpen)
+  }
+}
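
The FileSegment test is the zero-copy check: rather than reading the segment
into a buffer, the handler emits a BlockHeader followed by a Netty FileRegion,
so the transport can hand the (offset, length) range to sendfile(2) and skip
the user-space copy. A minimal sketch of that mapping, assuming the
(file, offset, length) triple from Spark's FileSegment:

    import java.io.{File, FileInputStream}
    import io.netty.channel.DefaultFileRegion

    // position and count line up with the test's assertions on out2.
    def toFileRegion(file: File, offset: Long, length: Long): DefaultFileRegion =
      new DefaultFileRegion(new FileInputStream(file).getChannel, offset, length)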

http://git-wip-us.apache.org/repos/asf/spark/blob/3f23d2a3/pom.xml
----------------------------------------------------------------------
diff --git a/pom.xml b/pom.xml
index c87f776..da401c9 100644
--- a/pom.xml
+++ b/pom.xml
@@ -419,7 +419,7 @@
       <dependency>
         <groupId>io.netty</groupId>
         <artifactId>netty-all</artifactId>
-        <version>4.0.17.Final</version>
+        <version>4.0.22.Final</version>
       </dependency>
       <dependency>
         <groupId>org.apache.derby</groupId>



