spark-commits mailing list archives

From pwend...@apache.org
Subject [1/3] [SPARK-1332] Improve Spark Streaming's Network Receiver and InputDStream API [WIP]
Date Tue, 22 Apr 2014 02:04:59 GMT
Repository: spark
Updated Branches:
  refs/heads/master 5a5b3346c -> 04c37b6f7


http://git-wip-us.apache.org/repos/asf/spark/blob/04c37b6f/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
new file mode 100644
index 0000000..3d2537f
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
@@ -0,0 +1,278 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler
+
+import scala.collection.mutable.{HashMap, SynchronizedMap, SynchronizedQueue}
+import scala.language.existentials
+
+import akka.actor._
+import org.apache.spark.{Logging, SparkEnv, SparkException}
+import org.apache.spark.SparkContext._
+import org.apache.spark.storage.StreamBlockId
+import org.apache.spark.streaming.{StreamingContext, Time}
+import org.apache.spark.streaming.receiver.{Receiver, ReceiverSupervisorImpl, StopReceiver}
+import org.apache.spark.util.AkkaUtils
+
+/** Information about a receiver */
+case class ReceiverInfo(streamId: Int, typ: String, location: String) {
+  override def toString = s"$typ-$streamId"
+}
+
+/** Information about blocks received by the receiver */
+case class ReceivedBlockInfo(
+    streamId: Int,
+    blockId: StreamBlockId,
+    numRecords: Long,
+    metadata: Any
+  )
+
+/**
+ * Messages used by the NetworkReceiver and the ReceiverTracker to communicate
+ * with each other.
+ */
+private[streaming] sealed trait ReceiverTrackerMessage
+private[streaming] case class RegisterReceiver(
+    streamId: Int,
+    typ: String,
+    host: String,
+    receiverActor: ActorRef
+  ) extends ReceiverTrackerMessage
+private[streaming] case class AddBlock(receivedBlockInfo: ReceivedBlockInfo)
+  extends ReceiverTrackerMessage
+private[streaming] case class ReportError(streamId: Int, message: String, error: String)
+private[streaming] case class DeregisterReceiver(streamId: Int, msg: String, error: String)
+  extends ReceiverTrackerMessage
+
+/**
+ * This class manages the execution of the receivers of NetworkInputDStreams. An instance of
+ * this class must be created after all input streams have been added and StreamingContext.start()
+ * has been called because it needs the final set of input streams at the time of instantiation.
+ */
+private[streaming]
+class ReceiverTracker(ssc: StreamingContext) extends Logging {
+
+  val receiverInputStreams = ssc.graph.getReceiverInputStreams()
+  val receiverInputStreamMap = Map(receiverInputStreams.map(x => (x.id, x)): _*)
+  val receiverExecutor = new ReceiverLauncher()
+  val receiverInfo = new HashMap[Int, ActorRef] with SynchronizedMap[Int, ActorRef]
+  val receivedBlockInfo = new HashMap[Int, SynchronizedQueue[ReceivedBlockInfo]]
+    with SynchronizedMap[Int, SynchronizedQueue[ReceivedBlockInfo]]
+  val timeout = AkkaUtils.askTimeout(ssc.conf)
+  val listenerBus = ssc.scheduler.listenerBus
+
+  // The actor is created when the tracker starts and set to null when it stops.
+  // It being non-null means the tracker has been started and not yet stopped.
+  var actor: ActorRef = null
+  var currentTime: Time = null
+
+  /** Start the actor and receiver execution thread. */
+  def start() = synchronized {
+    if (actor != null) {
+      throw new SparkException("ReceiverTracker already started")
+    }
+
+    if (!receiverInputStreams.isEmpty) {
+      actor = ssc.env.actorSystem.actorOf(Props(new ReceiverTrackerActor),
+        "ReceiverTracker")
+      receiverExecutor.start()
+      logInfo("ReceiverTracker started")
+    }
+  }
+
+  /** Stop the receiver execution thread. */
+  def stop() = synchronized {
+    if (!receiverInputStreams.isEmpty && actor != null) {
+      // First, stop the receivers
+      receiverExecutor.stop()
+
+      // Finally, stop the actor
+      ssc.env.actorSystem.stop(actor)
+      actor = null
+      logInfo("ReceiverTracker stopped")
+    }
+  }
+
+  /** Return all the blocks received from a receiver. */
+  def getReceivedBlockInfo(streamId: Int): Array[ReceivedBlockInfo] = {
+    val receivedBlockInfo = getReceivedBlockInfoQueue(streamId).dequeueAll(x => true)
+    logInfo("Stream " + streamId + " received " + receivedBlockInfo.size + " blocks")
+    receivedBlockInfo.toArray
+  }
+
+  private def getReceivedBlockInfoQueue(streamId: Int) = {
+    receivedBlockInfo.getOrElseUpdate(streamId, new SynchronizedQueue[ReceivedBlockInfo])
+  }
+
+  /** Register a receiver */
+  def registerReceiver(
+      streamId: Int,
+      typ: String,
+      host: String,
+      receiverActor: ActorRef,
+      sender: ActorRef
+    ) {
+    if (!receiverInputStreamMap.contains(streamId)) {
+      throw new Exception("Registration received for unexpected stream id " + streamId)
+    }
+    receiverInfo += ((streamId, receiverActor))
+    ssc.scheduler.listenerBus.post(StreamingListenerReceiverStarted(
+      ReceiverInfo(streamId, typ, host)
+    ))
+    logInfo("Registered receiver for stream " + streamId + " from " + sender.path.address)
+  }
+
+  /** Deregister a receiver */
+  def deregisterReceiver(streamId: Int, message: String, error: String) {
+    receiverInfo -= streamId
+    ssc.scheduler.listenerBus.post(StreamingListenerReceiverStopped(streamId, message, error))
+    val messageWithError = if (error != null && !error.isEmpty) {
+      s"$message - $error"
+    } else {
+      s"$message"
+    }
+    logError(s"Deregistered receiver for stream $streamId: $messageWithError")
+  }
+
+  /** Add new blocks for the given stream */
+  def addBlocks(receivedBlockInfo: ReceivedBlockInfo) {
+    getReceivedBlockInfoQueue(receivedBlockInfo.streamId) += receivedBlockInfo
+    logDebug("Stream " + receivedBlockInfo.streamId + " received new blocks: " +
+      receivedBlockInfo.blockId)
+  }
+
+  /** Report error sent by a receiver */
+  def reportError(streamId: Int, message: String, error: String) {
+    ssc.scheduler.listenerBus.post(StreamingListenerReceiverError(streamId, message, error))
+    val messageWithError = if (error != null && !error.isEmpty) {
+      s"$message - $error"
+    } else {
+      s"$message"
+    }
+    logWarning(s"Error reported by receiver for stream $streamId: $messageWithError")
+  }
+
+  /** Check if any blocks are left to be processed */
+  def hasMoreReceivedBlockIds: Boolean = {
+    !receivedBlockInfo.values.forall(_.isEmpty)
+  }
+
+  /** Actor to receive messages from the receivers. */
+  private class ReceiverTrackerActor extends Actor {
+    def receive = {
+      case RegisterReceiver(streamId, typ, host, receiverActor) =>
+        registerReceiver(streamId, typ, host, receiverActor, sender)
+        sender ! true
+      case AddBlock(receivedBlockInfo) =>
+        addBlocks(receivedBlockInfo)
+      case ReportError(streamId, message, error) =>
+        reportError(streamId, message, error)
+      case DeregisterReceiver(streamId, message, error) =>
+        deregisterReceiver(streamId, message, error)
+        sender ! true
+    }
+  }
+
+  /** This class runs all the receivers on the cluster from a dedicated thread. */
+  class ReceiverLauncher {
+    @transient val env = ssc.env
+    @transient val thread = new Thread() {
+      override def run() {
+        try {
+          SparkEnv.set(env)
+          startReceivers()
+        } catch {
+          case ie: InterruptedException => logInfo("ReceiverLauncher interrupted")
+        }
+      }
+    }
+
+    def start() {
+      thread.start()
+    }
+
+    def stop() {
+      // Send the stop signal to all the receivers
+      stopReceivers()
+
+      // Wait for the Spark job that runs the receivers to be over
+      // That is, for the receivers to quit gracefully.
+      thread.join(10000)
+
+      // Check if all the receivers have been deregistered or not
+      if (!receiverInfo.isEmpty) {
+        logWarning("Not all of the receivers have deregistered: " + receiverInfo)
+      } else {
+        logInfo("All of the receivers have deregistered successfully")
+      }
+    }
+
+    /**
+     * Gets the receivers from the ReceiverInputDStreams, distributes them to the
+     * worker nodes as a parallel collection, and runs them.
+     */
+    private def startReceivers() {
+      val receivers = receiverInputStreams.map(nis => {
+        val rcvr = nis.getReceiver()
+        rcvr.setReceiverId(nis.id)
+        rcvr
+      })
+
+      // Right now, we only honor preferences if all receivers have them
+      val hasLocationPreferences = receivers.map(_.preferredLocation.isDefined).reduce(_ && _)
+
+      // Create the parallel collection of receivers to distribute them on the worker nodes
+      val tempRDD =
+        if (hasLocationPreferences) {
+          val receiversWithPreferences = receivers.map(r => (r, Seq(r.preferredLocation.get)))
+          ssc.sc.makeRDD[Receiver[_]](receiversWithPreferences)
+        }
+        else {
+          ssc.sc.makeRDD(receivers, receivers.size)
+        }
+
+      // Function to start the receiver on the worker node
+      val startReceiver = (iterator: Iterator[Receiver[_]]) => {
+        if (!iterator.hasNext) {
+          throw new SparkException(
+            "Could not start receiver as object not found.")
+        }
+        val receiver = iterator.next()
+        val executor = new ReceiverSupervisorImpl(receiver, SparkEnv.get)
+        executor.start()
+        executor.awaitTermination()
+      }
+      // Run a dummy Spark job to ensure that all slaves have registered.
+      // This prevents all the receivers from being scheduled on the same node.
+      if (!ssc.sparkContext.isLocal) {
+        ssc.sparkContext.makeRDD(1 to 50, 50).map(x => (x, 1)).reduceByKey(_ + _, 20).collect()
+      }
+
+      // Distribute the receivers and start them
+      logInfo("Starting " + receivers.length + " receivers")
+      ssc.sparkContext.runJob(tempRDD, startReceiver)
+      logInfo("All of the receivers have been terminated")
+    }
+
+    /** Stops the receivers. */
+    private def stopReceivers() {
+      // Signal the receivers to stop
+      receiverInfo.values.foreach(_ ! StopReceiver)
+      logInfo("Sent stop signal to all " + receiverInfo.size + " receivers")
+    }
+  }
+}
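
The tracker keeps one synchronized queue of ReceivedBlockInfo per stream, and getReceivedBlockInfo() drains that queue. As a hedged sketch of how a driver-side caller might collect the blocks for one batch (the helper name collectBlocksForBatch is hypothetical, and the code assumes it lives inside the org.apache.spark.streaming package since ReceiverTracker is private[streaming]; in this patch the batch-generation path plays this role, as the receivedBlockInfo field on BatchInfo suggests):

    import org.apache.spark.streaming.scheduler.{ReceivedBlockInfo, ReceiverTracker}

    // Drain each stream's queue; every block ends up in exactly one batch
    // because getReceivedBlockInfo dequeues all blocks received so far.
    def collectBlocksForBatch(
        tracker: ReceiverTracker,
        streamIds: Seq[Int]): Map[Int, Array[ReceivedBlockInfo]] = {
      streamIds.map { streamId =>
        (streamId, tracker.getReceivedBlockInfo(streamId))
      }.toMap
    }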

http://git-wip-us.apache.org/repos/asf/spark/blob/04c37b6f/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala
index 5db40eb..9d6ec1f 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.streaming.scheduler
 
 import scala.collection.mutable.Queue
+
 import org.apache.spark.util.Distribution
 
 /** Base trait for events related to StreamingListener */
@@ -26,8 +27,13 @@ sealed trait StreamingListenerEvent
 case class StreamingListenerBatchSubmitted(batchInfo: BatchInfo) extends StreamingListenerEvent
 case class StreamingListenerBatchCompleted(batchInfo: BatchInfo) extends StreamingListenerEvent
 case class StreamingListenerBatchStarted(batchInfo: BatchInfo) extends StreamingListenerEvent
+
 case class StreamingListenerReceiverStarted(receiverInfo: ReceiverInfo)
   extends StreamingListenerEvent
+case class StreamingListenerReceiverError(streamId: Int, message: String, error: String)
+  extends StreamingListenerEvent
+case class StreamingListenerReceiverStopped(streamId: Int, message: String, error: String)
+  extends StreamingListenerEvent
 
 /** An event used in the listener to shutdown the listener daemon thread. */
 private[scheduler] case object StreamingListenerShutdown extends StreamingListenerEvent
@@ -41,14 +47,20 @@ trait StreamingListener {
   /** Called when a receiver has been started */
   def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) { }
 
+  /** Called when a receiver has reported an error */
+  def onReceiverError(receiverError: StreamingListenerReceiverError) { }
+
+  /** Called when a receiver has been stopped */
+  def onReceiverStopped(receiverStopped: StreamingListenerReceiverStopped) { }
+
   /** Called when a batch of jobs has been submitted for processing. */
   def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted) { }
 
-  /** Called when processing of a batch of jobs has completed. */
-  def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) { }
-
   /** Called when processing of a batch of jobs has started.  */
   def onBatchStarted(batchStarted: StreamingListenerBatchStarted) { }
+
+  /** Called when processing of a batch of jobs has completed. */
+  def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) { }
 }
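
For illustration, a minimal listener that implements the two callbacks added here (onReceiverError and onReceiverStopped) alongside the existing onReceiverStarted; only the println sink and the class name are made up:

    import org.apache.spark.streaming.scheduler._

    // Logs every receiver life-cycle event posted on the listener bus.
    class ReceiverEventLogger extends StreamingListener {
      override def onReceiverStarted(started: StreamingListenerReceiverStarted) {
        println("Receiver started: " + started.receiverInfo)
      }
      override def onReceiverError(error: StreamingListenerReceiverError) {
        println("Receiver " + error.streamId + " error: " + error.message + " - " + error.error)
      }
      override def onReceiverStopped(stopped: StreamingListenerReceiverStopped) {
        println("Receiver " + stopped.streamId + " stopped: " + stopped.message)
      }
    }

It would be registered the usual way, e.g. ssc.addStreamingListener(new ReceiverEventLogger).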
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/04c37b6f/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala
index ea03dfc..398724d 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala
@@ -40,6 +40,10 @@ private[spark] class StreamingListenerBus() extends Logging {
         event match {
           case receiverStarted: StreamingListenerReceiverStarted =>
             listeners.foreach(_.onReceiverStarted(receiverStarted))
+          case receiverError: StreamingListenerReceiverError =>
+            listeners.foreach(_.onReceiverError(receiverError))
+          case receiverStopped: StreamingListenerReceiverStopped =>
+            listeners.foreach(_.onReceiverStopped(receiverStopped))
           case batchSubmitted: StreamingListenerBatchSubmitted =>
             listeners.foreach(_.onBatchSubmitted(batchSubmitted))
           case batchStarted: StreamingListenerBatchStarted =>

http://git-wip-us.apache.org/repos/asf/spark/blob/04c37b6f/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala
index 8b025b0..bf637c1 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala
@@ -62,8 +62,8 @@ private[ui] class StreamingJobProgressListener(ssc: StreamingContext) extends St
     totalCompletedBatches += 1L
   }
 
-  def numNetworkReceivers = synchronized {
-    ssc.graph.getNetworkInputStreams().size
+  def numReceivers = synchronized {
+    ssc.graph.getReceiverInputStreams().size
   }
 
   def numTotalCompletedBatches: Long = synchronized {
@@ -101,7 +101,7 @@ private[ui] class StreamingJobProgressListener(ssc: StreamingContext) extends St
   def receivedRecordsDistributions: Map[Int, Option[Distribution]] = synchronized {
     val latestBatchInfos = retainedBatches.reverse.take(batchInfoLimit)
     val latestBlockInfos = latestBatchInfos.map(_.receivedBlockInfo)
-    (0 until numNetworkReceivers).map { receiverId =>
+    (0 until numReceivers).map { receiverId =>
       val blockInfoOfParticularReceiver = latestBlockInfos.map { batchInfo =>
         batchInfo.get(receiverId).getOrElse(Array.empty)
       }
@@ -117,11 +117,11 @@ private[ui] class StreamingJobProgressListener(ssc: StreamingContext) extends St
   def lastReceivedBatchRecords: Map[Int, Long] = {
     val lastReceivedBlockInfoOption = lastReceivedBatch.map(_.receivedBlockInfo)
     lastReceivedBlockInfoOption.map { lastReceivedBlockInfo =>
-      (0 until numNetworkReceivers).map { receiverId =>
+      (0 until numReceivers).map { receiverId =>
         (receiverId, lastReceivedBlockInfo(receiverId).map(_.numRecords).sum)
       }.toMap
     }.getOrElse {
-      (0 until numNetworkReceivers).map(receiverId => (receiverId, 0L)).toMap
+      (0 until numReceivers).map(receiverId => (receiverId, 0L)).toMap
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/04c37b6f/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala
index 6607437..8fe1219 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala
@@ -40,7 +40,7 @@ private[ui] class StreamingPage(parent: StreamingTab)
     val content =
       generateBasicStats() ++ <br></br> ++
      <h4>Statistics over last {listener.retainedCompletedBatches.size} processed batches</h4> ++
-      generateNetworkStatsTable() ++
+      generateReceiverStats() ++
       generateBatchStatsTable()
     UIUtils.headerSparkPage(
       content, parent.basePath, parent.appName, "Streaming", parent.headerTabs, parent, Some(5000))
@@ -57,7 +57,7 @@ private[ui] class StreamingPage(parent: StreamingTab)
         <strong>Time since start: </strong>{formatDurationVerbose(timeSinceStart)}
       </li>
       <li>
-        <strong>Network receivers: </strong>{listener.numNetworkReceivers}
+        <strong>Network receivers: </strong>{listener.numReceivers}
       </li>
       <li>
         <strong>Batch interval: </strong>{formatDurationVerbose(listener.batchDuration)}
@@ -71,8 +71,8 @@ private[ui] class StreamingPage(parent: StreamingTab)
     </ul>
   }
 
-  /** Generate stats of data received over the network the streaming program */
-  private def generateNetworkStatsTable(): Seq[Node] = {
+  /** Generate stats of data received by the receivers in the streaming program */
+  private def generateReceiverStats(): Seq[Node] = {
     val receivedRecordDistributions = listener.receivedRecordsDistributions
     val lastBatchReceivedRecord = listener.lastReceivedBatchRecords
     val table = if (receivedRecordDistributions.size > 0) {
@@ -86,13 +86,13 @@ private[ui] class StreamingPage(parent: StreamingTab)
         "75th percentile rate\n[records/sec]",
         "Maximum rate\n[records/sec]"
       )
-      val dataRows = (0 until listener.numNetworkReceivers).map { receiverId =>
+      val dataRows = (0 until listener.numReceivers).map { receiverId =>
         val receiverInfo = listener.receiverInfo(receiverId)
         val receiverName = receiverInfo.map(_.toString).getOrElse(s"Receiver-$receiverId")
         val receiverLocation = receiverInfo.map(_.location).getOrElse(emptyCell)
-        val receiverLastBatchRecords = formatDurationVerbose(lastBatchReceivedRecord(receiverId))
+        val receiverLastBatchRecords = formatNumber(lastBatchReceivedRecord(receiverId))
         val receivedRecordStats = receivedRecordDistributions(receiverId).map { d =>
-          d.getQuantiles().map(r => formatDurationVerbose(r.toLong))
+          d.getQuantiles().map(r => formatNumber(r.toLong))
         }.getOrElse {
           Seq(emptyCell, emptyCell, emptyCell, emptyCell, emptyCell)
         }
@@ -104,8 +104,8 @@ private[ui] class StreamingPage(parent: StreamingTab)
     }
 
     val content =
-      <h5>Network Input Statistics</h5> ++
-      <div>{table.getOrElse("No network receivers")}</div>
+      <h5>Receiver Statistics</h5> ++
+      <div>{table.getOrElse("No receivers")}</div>
 
     content
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/04c37b6f/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala
index e016377..1a616a0 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala
@@ -77,7 +77,9 @@ class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name:
   def stop(interruptTimer: Boolean): Long = synchronized {
     if (!stopped) {
       stopped = true
-      if (interruptTimer) thread.interrupt()
+      if (interruptTimer) {
+        thread.interrupt()
+      }
       thread.join()
       logInfo("Stopped timer for " + name + " after time " + prevTime)
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/04c37b6f/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
----------------------------------------------------------------------
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
index a0b1bbc..f9bfb9b 100644
--- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
+++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
@@ -17,6 +17,7 @@
 
 package org.apache.spark.streaming;
 
+import org.apache.spark.streaming.api.java.*;
 import scala.Tuple2;
 
 import org.junit.Assert;
@@ -36,10 +37,6 @@ import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.*;
 import org.apache.spark.storage.StorageLevel;
-import org.apache.spark.streaming.api.java.JavaDStream;
-import org.apache.spark.streaming.api.java.JavaDStreamLike;
-import org.apache.spark.streaming.api.java.JavaPairDStream;
-import org.apache.spark.streaming.api.java.JavaStreamingContext;
 
 // The test suite itself is Serializable so that anonymous Function implementations can be
 // serialized, as an alternative to converting these anonymous classes to static inner classes;
@@ -1668,7 +1665,7 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa
   // InputStream functionality is deferred to the existing Scala tests.
   @Test
   public void testSocketTextStream() {
-    JavaDStream<String> test = ssc.socketTextStream("localhost", 12345);
+    JavaReceiverInputDStream<String> test = ssc.socketTextStream("localhost", 12345);
   }
 
   @Test
@@ -1701,6 +1698,6 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa
 
   @Test
   public void testRawSocketStream() {
-    JavaDStream<String> test = ssc.rawSocketStream("localhost", 12345);
+    JavaReceiverInputDStream<String> test = ssc.rawSocketStream("localhost", 12345);
   }
 }
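
The Java test changes mirror the API shift on the Scala side: the socket sources now return a receiver-specific input stream type rather than a plain DStream. A hedged Scala sketch (assuming the Scala counterpart is ReceiverInputDStream in org.apache.spark.streaming.dstream, which the getReceiverInputStreams() calls elsewhere in this patch suggest):

    import org.apache.spark.streaming.{Seconds, StreamingContext}
    import org.apache.spark.streaming.dstream.ReceiverInputDStream

    val ssc = new StreamingContext("local[2]", "example", Seconds(1))
    // The more specific static type is new; code written against DStream still compiles.
    val lines: ReceiverInputDStream[String] = ssc.socketTextStream("localhost", 9999)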

http://git-wip-us.apache.org/repos/asf/spark/blob/04c37b6f/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
index 952511d..46b7f63 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
@@ -36,10 +36,9 @@ import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.Logging
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.streaming.dstream.NetworkReceiver
-import org.apache.spark.streaming.receivers.Receiver
 import org.apache.spark.streaming.util.ManualClock
 import org.apache.spark.util.Utils
+import org.apache.spark.streaming.receiver.{ActorHelper, Receiver}
 
 class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
 
@@ -207,7 +206,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
 
     // set up the network stream using the test receiver
     val ssc = new StreamingContext(conf, batchDuration)
-    val networkStream = ssc.networkStream[Int](testReceiver)
+    val networkStream = ssc.receiverStream[Int](testReceiver)
     val countStream = networkStream.count
     val outputBuffer = new ArrayBuffer[Seq[Long]] with SynchronizedBuffer[Seq[Long]]
     val outputStream = new TestOutputStream(countStream, outputBuffer)
@@ -301,7 +300,7 @@ object TestServer {
 }
 
 /** This is an actor for testing actor input stream */
-class TestActor(port: Int) extends Actor with Receiver {
+class TestActor(port: Int) extends Actor with ActorHelper {
 
   def bytesToString(byteString: ByteString) = byteString.utf8String
 
@@ -309,24 +308,22 @@ class TestActor(port: Int) extends Actor with Receiver {
 
   def receive = {
     case IO.Read(socket, bytes) =>
-      pushBlock(bytesToString(bytes))
+      store(bytesToString(bytes))
   }
 }
 
 /** This is a receiver to test multiple threads inserting data using block generator */
 class MultiThreadTestReceiver(numThreads: Int, numRecordsPerThread: Int)
-  extends NetworkReceiver[Int] {
+  extends Receiver[Int](StorageLevel.MEMORY_ONLY_SER) with Logging {
   lazy val executorPool = Executors.newFixedThreadPool(numThreads)
-  lazy val blockGenerator = new BlockGenerator(StorageLevel.MEMORY_ONLY)
   lazy val finishCount = new AtomicInteger(0)
 
-  protected def onStart() {
-    blockGenerator.start()
+  def onStart() {
     (1 to numThreads).map(threadId => {
       val runnable = new Runnable {
         def run() {
           (1 to numRecordsPerThread).foreach(i =>
-            blockGenerator += (threadId * numRecordsPerThread + i) )
+            store(threadId * numRecordsPerThread + i) )
           if (finishCount.incrementAndGet == numThreads) {
             MultiThreadTestReceiver.haveAllThreadsFinished = true
           }
@@ -337,7 +334,7 @@ class MultiThreadTestReceiver(numThreads: Int, numRecordsPerThread: Int)
     })
   }
 
-  protected def onStop() {
+  def onStop() {
     executorPool.shutdown()
   }
 }
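
The hunk above migrates the test receivers from the old NetworkReceiver/BlockGenerator pattern to the new Receiver API: implement onStart()/onStop() and push records with store(). A self-contained sketch of the same pattern with a made-up socket source (the LineReceiver name, host, and port are illustrative, not part of the patch):

    import java.net.Socket
    import scala.io.Source
    import org.apache.spark.storage.StorageLevel
    import org.apache.spark.streaming.receiver.Receiver

    class LineReceiver(host: String, port: Int)
      extends Receiver[String](StorageLevel.MEMORY_ONLY) {

      def onStart() {
        // Receive on a dedicated thread; onStart() must return quickly.
        new Thread("Line Receiver") {
          override def run() {
            val socket = new Socket(host, port)
            try {
              val lines = Source.fromInputStream(socket.getInputStream).getLines()
              while (!isStopped() && lines.hasNext) {
                store(lines.next())  // hand the record to the supervisor
              }
            } finally {
              socket.close()
            }
          }
        }.start()
      }

      // No cleanup needed: the receiving thread exits once isStopped() is true.
      def onStop() { }
    }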

http://git-wip-us.apache.org/repos/asf/spark/blob/04c37b6f/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala
new file mode 100644
index 0000000..5c0415a
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming
+
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.SparkConf
+import org.apache.spark.storage.{StorageLevel, StreamBlockId}
+import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver,
+  ReceiverSupervisor}
+import org.scalatest.FunSuite
+import org.scalatest.concurrent.Timeouts
+import org.scalatest.concurrent.Eventually._
+import org.scalatest.time.SpanSugar._
+
+/** Test suite for the network receiver behavior */
+class NetworkReceiverSuite extends FunSuite with Timeouts {
+
+  test("network receiver life cycle") {
+
+    val receiver = new FakeReceiver
+    val executor = new FakeReceiverSupervisor(receiver)
+
+    assert(executor.isAllEmpty)
+
+    // Thread that runs the executor
+    val executingThread = new Thread() {
+      override def run() {
+        executor.start()
+        executor.awaitTermination()
+      }
+    }
+
+    // Start the receiver
+    executingThread.start()
+
+    // Verify that the receiver keeps running, i.e. joining the executing thread times out
+    intercept[Exception] {
+      failAfter(200 millis) {
+        executingThread.join()
+      }
+    }
+
+    // Verify that receiver was started
+    assert(receiver.onStartCalled)
+    assert(executor.isReceiverStarted)
+    assert(receiver.isStarted)
+    assert(!receiver.isStopped())
+    assert(receiver.otherThread.isAlive)
+    eventually(timeout(100 millis), interval(10 millis)) {
+      assert(receiver.receiving)
+    }
+
+    // Verify whether the data stored by the receiver was sent to the executor
+    val byteBuffer = ByteBuffer.allocate(100)
+    val arrayBuffer = new ArrayBuffer[Int]()
+    val iterator = arrayBuffer.iterator
+    receiver.store(1)
+    receiver.store(byteBuffer)
+    receiver.store(arrayBuffer)
+    receiver.store(iterator)
+    assert(executor.singles.size === 1)
+    assert(executor.singles.head === 1)
+    assert(executor.byteBuffers.size === 1)
+    assert(executor.byteBuffers.head.eq(byteBuffer))
+    assert(executor.iterators.size === 1)
+    assert(executor.iterators.head.eq(iterator))
+    assert(executor.arrayBuffers.size === 1)
+    assert(executor.arrayBuffers.head.eq(arrayBuffer))
+
+    // Verify whether the exceptions reported by the receiver were sent to the executor
+    val exception = new Exception
+    receiver.reportError("Error", exception)
+    assert(executor.errors.size === 1)
+    assert(executor.errors.head.eq(exception))
+
+    // Verify restarting actually stops and starts the receiver
+    receiver.restart("restarting", null, 100)
+    assert(receiver.isStopped)
+    assert(receiver.onStopCalled)
+    eventually(timeout(1000 millis), interval(100 millis)) {
+      assert(receiver.onStartCalled)
+      assert(executor.isReceiverStarted)
+      assert(receiver.isStarted)
+      assert(!receiver.isStopped)
+      assert(receiver.receiving)
+    }
+
+    // Verify that stopping actually stops the thread
+    failAfter(100 millis) {
+      receiver.stop("test")
+      assert(receiver.isStopped)
+      assert(!receiver.otherThread.isAlive)
+
+      // The thread that started the executor should complete
+      // as stop() stops everything
+      executingThread.join()
+    }
+  }
+
+  test("block generator") {
+    val blockGeneratorListener = new FakeBlockGeneratorListener
+    val blockInterval = 200
+    val conf = new SparkConf().set("spark.streaming.blockInterval", blockInterval.toString)
+    val blockGenerator = new BlockGenerator(blockGeneratorListener, 1, conf)
+    val expectedBlocks = 5
+    val waitTime = expectedBlocks * blockInterval + (blockInterval / 2)
+    val generatedData = new ArrayBuffer[Int]
+
+    // Generate blocks
+    val startTime = System.currentTimeMillis()
+    blockGenerator.start()
+    var count = 0
+    while(System.currentTimeMillis - startTime < waitTime) {
+      blockGenerator += count
+      generatedData += count
+      count += 1
+      Thread.sleep(10)
+    }
+    blockGenerator.stop()
+
+    val recordedData = blockGeneratorListener.arrayBuffers.flatten
+    assert(blockGeneratorListener.arrayBuffers.size > 0)
+    assert(recordedData.toSet === generatedData.toSet)
+  }
+
+  /**
+   * An implementation of Receiver that is used for testing a receiver's life cycle.
+   */
+  class FakeReceiver extends Receiver[Int](StorageLevel.MEMORY_ONLY) {
+    var otherThread: Thread = null
+    var receiving = false
+    var onStartCalled = false
+    var onStopCalled = false
+
+    def onStart() {
+      otherThread = new Thread() {
+        override def run() {
+          receiving = true
+          while(!isStopped()) {
+            Thread.sleep(10)
+          }
+        }
+      }
+      onStartCalled = true
+      otherThread.start()
+
+    }
+
+    def onStop() {
+      onStopCalled = true
+      otherThread.join()
+    }
+
+    def reset() {
+      receiving = false
+      onStartCalled = false
+      onStopCalled = false
+    }
+  }
+
+  /**
+   * An implementation of ReceiverSupervisor used for testing a Receiver.
+   * Instead of storing the data in the BlockManager, it stores all the data in a local buffer
+   * that can be used for verifying that the data has been forwarded correctly.
+   */
+  class FakeReceiverSupervisor(receiver: FakeReceiver)
+    extends ReceiverSupervisor(receiver, new SparkConf()) {
+    val singles = new ArrayBuffer[Any]
+    val byteBuffers = new ArrayBuffer[ByteBuffer]
+    val iterators = new ArrayBuffer[Iterator[_]]
+    val arrayBuffers = new ArrayBuffer[ArrayBuffer[_]]
+    val errors = new ArrayBuffer[Throwable]
+
+    /** Check if all data structures are clean */
+    def isAllEmpty = {
+      singles.isEmpty && byteBuffers.isEmpty && iterators.isEmpty &&
+        arrayBuffers.isEmpty && errors.isEmpty
+    }
+
+    def pushSingle(data: Any) {
+      singles += data
+    }
+
+    def pushBytes(
+        bytes: ByteBuffer,
+        optionalMetadata: Option[Any],
+        optionalBlockId: Option[StreamBlockId]
+      ) {
+      byteBuffers += bytes
+    }
+
+    def pushIterator(
+        iterator: Iterator[_],
+        optionalMetadata: Option[Any],
+        optionalBlockId: Option[StreamBlockId]
+      ) {
+      iterators += iterator
+    }
+
+    def pushArrayBuffer(
+        arrayBuffer: ArrayBuffer[_],
+        optionalMetadata: Option[Any],
+        optionalBlockId: Option[StreamBlockId]
+      ) {
+      arrayBuffers += arrayBuffer
+    }
+
+    def reportError(message: String, throwable: Throwable) {
+      errors += throwable
+    }
+  }
+
+  /**
+   * An implementation of BlockGeneratorListener that is used to test the BlockGenerator.
+   */
+  class FakeBlockGeneratorListener(pushDelay: Long = 0) extends BlockGeneratorListener {
+    // buffer of data received as ArrayBuffers
+    val arrayBuffers = new ArrayBuffer[ArrayBuffer[Int]]
+    val errors = new ArrayBuffer[Throwable]
+
+    def onPushBlock(blockId: StreamBlockId, arrayBuffer: ArrayBuffer[_]) {
+      val bufferOfInts = arrayBuffer.map(_.asInstanceOf[Int])
+      arrayBuffers += bufferOfInts
+      Thread.sleep(0)
+    }
+
+    def onError(message: String, throwable: Throwable) {
+      errors += throwable
+    }
+  }
+}
+
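
The "block generator" test above exercises the full BlockGenerator API: construct it with a listener, a stream id, and a SparkConf, push records with +=, and receive completed blocks through onPushBlock. A standalone sketch of the same wiring (assuming the code sits in the org.apache.spark.streaming package, as the suite does, in case BlockGenerator is not public):

    package org.apache.spark.streaming

    import scala.collection.mutable.ArrayBuffer
    import org.apache.spark.SparkConf
    import org.apache.spark.storage.StreamBlockId
    import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener}

    object BlockGeneratorSketch {
      def main(args: Array[String]) {
        val blocks = new ArrayBuffer[Seq[Int]]
        val listener = new BlockGeneratorListener {
          def onPushBlock(blockId: StreamBlockId, arrayBuffer: ArrayBuffer[_]) {
            blocks += arrayBuffer.map(_.asInstanceOf[Int])  // one entry per block interval
          }
          def onError(message: String, throwable: Throwable) { }
        }
        val conf = new SparkConf().set("spark.streaming.blockInterval", "200")
        val generator = new BlockGenerator(listener, 1, conf)
        generator.start()
        (1 to 100).foreach { i => generator += i; Thread.sleep(10) }  // ~1 second of data
        generator.stop()
        println("Generated " + blocks.size + " blocks")
      }
    }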

http://git-wip-us.apache.org/repos/asf/spark/blob/04c37b6f/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
index ad5367a..6d14b1f 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
@@ -21,7 +21,8 @@ import java.util.concurrent.atomic.AtomicInteger
 
 import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException}
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.streaming.dstream.{DStream, NetworkReceiver}
+import org.apache.spark.streaming.dstream.DStream
+import org.apache.spark.streaming.receiver.Receiver
 import org.apache.spark.util.{MetadataCleaner, Utils}
 import org.scalatest.{BeforeAndAfter, FunSuite}
 import org.scalatest.concurrent.Timeouts
@@ -181,15 +182,15 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w
     conf.set("spark.cleaner.ttl", "3600")
     sc = new SparkContext(conf)
     for (i <- 1 to 4) {
-      logInfo("==================================")
-      ssc = new StreamingContext(sc, batchDuration)
+      logInfo("==================================\n\n\n")
+      ssc = new StreamingContext(sc, Milliseconds(100))
       var runningCount = 0
       TestReceiver.counter.set(1)
       val input = ssc.networkStream(new TestReceiver)
       input.count.foreachRDD(rdd => {
         val count = rdd.first()
-        logInfo("Count = " + count)
         runningCount += count.toInt
+        logInfo("Count = " + count + ", Running count = " + runningCount)
       })
       ssc.start()
       ssc.awaitTermination(500)
@@ -216,12 +217,12 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w
       ssc.start()
     }
 
-    // test whether waitForStop() exits after give amount of time
+    // test whether awaitTermination() exits after the given amount of time
     failAfter(1000 millis) {
       ssc.awaitTermination(500)
     }
 
-    // test whether waitForStop() does not exit if not time is given
+    // test whether awaitTermination() does not exit if no time is given
     val exception = intercept[Exception] {
       failAfter(1000 millis) {
         ssc.awaitTermination()
@@ -276,23 +277,26 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w
 class TestException(msg: String) extends Exception(msg)
 
 /** Custom receiver for testing whether all data received by a receiver gets processed or not */
-class TestReceiver extends NetworkReceiver[Int] {
-  protected lazy val blockGenerator = new BlockGenerator(StorageLevel.MEMORY_ONLY)
-  protected def onStart() {
-    blockGenerator.start()
-    logInfo("BlockGenerator started on thread " + receivingThread)
-    try {
-      while(true) {
-        blockGenerator += TestReceiver.counter.getAndIncrement
-        Thread.sleep(0)
+class TestReceiver extends Receiver[Int](StorageLevel.MEMORY_ONLY) with Logging {
+
+  var receivingThreadOption: Option[Thread] = None
+
+  def onStart() {
+    val thread = new Thread() {
+      override def run() {
+        logInfo("Receiving started")
+        while (!isStopped) {
+          store(TestReceiver.counter.getAndIncrement)
+        }
+        logInfo("Receiving stopped at count value of " + TestReceiver.counter.get())
       }
-    } finally {
-      logInfo("Receiving stopped at count value of " + TestReceiver.counter.get())
     }
+    receivingThreadOption = Some(thread)
+    thread.start()
   }
 
-  protected def onStop() {
-    blockGenerator.stop()
+  def onStop() {
+    // no cleanup to be done, the receiving thread should stop on its own
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/04c37b6f/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
index 9e0f2c9..542c697 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
@@ -17,10 +17,19 @@
 
 package org.apache.spark.streaming
 
-import org.apache.spark.streaming.scheduler._
 import scala.collection.mutable.ArrayBuffer
-import org.scalatest.matchers.ShouldMatchers
+import scala.concurrent.Future
+import scala.concurrent.ExecutionContext.Implicits.global
+
+import org.apache.spark.storage.StorageLevel
 import org.apache.spark.streaming.dstream.DStream
+import org.apache.spark.streaming.receiver.Receiver
+import org.apache.spark.streaming.scheduler._
+
+import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.concurrent.Eventually._
+import org.scalatest.time.SpanSugar._
+import org.apache.spark.Logging
 
 class StreamingListenerSuite extends TestSuiteBase with ShouldMatchers {
 
@@ -32,7 +41,7 @@ class StreamingListenerSuite extends TestSuiteBase with ShouldMatchers {
   override def batchDuration = Milliseconds(100)
   override def actuallyWait = true
 
-  test("basic BatchInfo generation") {
+  test("batch info reporting") {
     val ssc = setupStreams(input, operation)
     val collector = new BatchInfoCollector
     ssc.addStreamingListener(collector)
@@ -54,6 +63,31 @@ class StreamingListenerSuite extends TestSuiteBase with ShouldMatchers {
     isInIncreasingOrder(batchInfos.map(_.processingEndTime.get)) should be (true)
   }
 
+  test("receiver info reporting") {
+    val ssc = new StreamingContext("local[2]", "test", Milliseconds(1000))
+    val inputStream = ssc.networkStream(new StreamingListenerSuiteReceiver)
+    inputStream.foreachRDD(_.count)
+
+    val collector = new ReceiverInfoCollector
+    ssc.addStreamingListener(collector)
+
+    ssc.start()
+    try {
+      eventually(timeout(1000 millis), interval(20 millis)) {
+        collector.startedReceiverInfo should have size 1
+        collector.startedReceiverInfo(0).streamId should equal (0)
+        collector.stoppedReceiverStreamIds should have size 1
+        collector.stoppedReceiverStreamIds(0) should equal (0)
+        collector.receiverErrors should have size 1
+        collector.receiverErrors(0)._1 should equal (0)
+        collector.receiverErrors(0)._2 should include ("report error")
+        collector.receiverErrors(0)._3 should include ("report exception")
+      }
+    } finally {
+      ssc.stop()
+    }
+  }
+
   /** Check if a sequence of numbers is in increasing order */
   def isInIncreasingOrder(seq: Seq[Long]): Boolean = {
     for(i <- 1 until seq.size) {
@@ -61,12 +95,46 @@ class StreamingListenerSuite extends TestSuiteBase with ShouldMatchers {
     }
     true
   }
+}
+
+/** Listener that collects information on processed batches */
+class BatchInfoCollector extends StreamingListener {
+  val batchInfos = new ArrayBuffer[BatchInfo]
+  override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) {
+    batchInfos += batchCompleted.batchInfo
+  }
+}
+
+/** Listener that collects information on receiver life-cycle events */
+class ReceiverInfoCollector extends StreamingListener {
+  val startedReceiverInfo = new ArrayBuffer[ReceiverInfo]
+  val stoppedReceiverStreamIds = new ArrayBuffer[Int]()
+  val receiverErrors = new ArrayBuffer[(Int, String, String)]()
+
+  override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) {
+    startedReceiverInfo += receiverStarted.receiverInfo
+  }
+
+  override def onReceiverStopped(receiverStopped: StreamingListenerReceiverStopped) {
+    stoppedReceiverStreamIds += receiverStopped.streamId
+  }
+
+  override def onReceiverError(receiverError: StreamingListenerReceiverError) {
+    receiverErrors += ((receiverError.streamId, receiverError.message, receiverError.error))
+  }
+}
 
-  /** Listener that collects information on processed batches */
-  class BatchInfoCollector extends StreamingListener {
-    val batchInfos = new ArrayBuffer[BatchInfo]
-    override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) {
-      batchInfos += batchCompleted.batchInfo
+class StreamingListenerSuiteReceiver extends Receiver[Any](StorageLevel.MEMORY_ONLY) with Logging {
+  def onStart() {
+    Future {
+      logInfo("Started receiver and sleeping")
+      Thread.sleep(10)
+      logInfo("Reporting error and sleeping")
+      reportError("test report error", new Exception("test report exception"))
+      Thread.sleep(10)
+      logInfo("Stopping")
+      stop("test stop error")
     }
   }
+  def onStop() { }
 }

