spark-commits mailing list archives

From: shiva...@apache.org
Subject: git commit: [SPARK-4031] Make torrent broadcast read blocks on use.
Date: Tue, 28 Oct 2014 17:14:29 GMT
Repository: spark
Updated Branches:
  refs/heads/master 0ac52e305 -> 7768a800d


[SPARK-4031] Make torrent broadcast read blocks on use.

This avoids reading torrent broadcast variables when they are referenced in the closure but
not actually used by it. This is done by using a `lazy val` to read the broadcast blocks.
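
As a standalone illustration of the `lazy val` semantics this relies on (illustrative Scala only, not Spark code):

    object LazyValSketch {
      private def fetchBlocks(): String = {
        // Stands in for the expensive block fetch; runs only on first access.
        println("fetching blocks")
        "broadcast value"
      }

      // Evaluation is deferred until `value` is first read, then cached.
      lazy val value: String = fetchBlocks()

      def main(args: Array[String]): Unit = {
        println("before access")  // nothing fetched yet
        println(value)            // triggers fetchBlocks() exactly once
        println(value)            // served from the cached result
      }
    }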

cc rxin JoshRosen for review

Author: Shivaram Venkataraman <shivaram@cs.berkeley.edu>

Closes #2871 from shivaram/broadcast-read-value and squashes the following commits:

1456d65 [Shivaram Venkataraman] Use getUsedTimeMs and remove readObject
d6c5ee9 [Shivaram Venkataraman] Use lazy val to implement readBroadcastBlock
0b34df7 [Shivaram Venkataraman] Merge branch 'master' of https://github.com/apache/spark into broadcast-read-value
9cec507 [Shivaram Venkataraman] Test if broadcast variables are read lazily
768b40b [Shivaram Venkataraman] Merge branch 'master' of https://github.com/apache/spark into broadcast-read-value
8792ed8 [Shivaram Venkataraman] Make torrent broadcast read blocks on use. This avoids reading broadcast variables when they are referenced in the closure but not used by the code.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7768a800
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7768a800
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7768a800

Branch: refs/heads/master
Commit: 7768a800d4c4c19d76cef1ee40af6900bbac821c
Parents: 0ac52e3
Author: Shivaram Venkataraman <shivaram@cs.berkeley.edu>
Authored: Tue Oct 28 10:14:16 2014 -0700
Committer: Shivaram Venkataraman <shivaram@cs.berkeley.edu>
Committed: Tue Oct 28 10:14:16 2014 -0700

----------------------------------------------------------------------
 .../spark/broadcast/TorrentBroadcast.scala      | 43 +++++++++++---------
 .../scala/org/apache/spark/util/Utils.scala     | 15 +++++++
 .../apache/spark/broadcast/BroadcastSuite.scala | 30 +++++++++++++-
 3 files changed, 67 insertions(+), 21 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7768a800/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index 75e64c1..94142d3 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -56,11 +56,13 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
   extends Broadcast[T](id) with Logging with Serializable {
 
   /**
-   * Value of the broadcast object. On driver, this is set directly by the constructor.
-   * On executors, this is reconstructed by [[readObject]], which builds this value by reading
-   * blocks from the driver and/or other executors.
+   * Value of the broadcast object on executors. This is reconstructed by [[readBroadcastBlock]],
+   * which builds this value by reading blocks from the driver and/or other executors.
+   *
+   * On the driver, if the value is required, it is read lazily from the block manager.
    */
-  @transient private var _value: T = obj
+  @transient private lazy val _value: T = readBroadcastBlock()
+
   /** The compression codec to use, or None if compression is disabled */
   @transient private var compressionCodec: Option[CompressionCodec] = _
   /** Size of each block. Default value is 4MB.  This value is only read by the broadcaster. */
@@ -79,22 +81,24 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
   private val broadcastId = BroadcastBlockId(id)
 
   /** Total number of blocks this broadcast variable contains. */
-  private val numBlocks: Int = writeBlocks()
+  private val numBlocks: Int = writeBlocks(obj)
 
-  override protected def getValue() = _value
+  override protected def getValue() = {
+    _value
+  }
 
   /**
    * Divide the object into multiple blocks and put those blocks in the block manager.
-   *
+   * @param value the object to divide
    * @return number of blocks this broadcast variable is divided into
    */
-  private def writeBlocks(): Int = {
+  private def writeBlocks(value: T): Int = {
     // Store a copy of the broadcast variable in the driver so that tasks run on the driver
     // do not create a duplicate copy of the broadcast variable's value.
-    SparkEnv.get.blockManager.putSingle(broadcastId, _value, StorageLevel.MEMORY_AND_DISK,
+    SparkEnv.get.blockManager.putSingle(broadcastId, value, StorageLevel.MEMORY_AND_DISK,
       tellMaster = false)
     val blocks =
-      TorrentBroadcast.blockifyObject(_value, blockSize, SparkEnv.get.serializer, compressionCodec)
+      TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
     blocks.zipWithIndex.foreach { case (block, i) =>
       SparkEnv.get.blockManager.putBytes(
         BroadcastBlockId(id, "piece" + i),
@@ -157,31 +161,30 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
     out.defaultWriteObject()
   }
 
-  /** Used by the JVM when deserializing this object. */
-  private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
-    in.defaultReadObject()
+  private def readBroadcastBlock(): T = Utils.tryOrIOException {
     TorrentBroadcast.synchronized {
       setConf(SparkEnv.get.conf)
       SparkEnv.get.blockManager.getLocal(broadcastId).map(_.data.next()) match {
         case Some(x) =>
-          _value = x.asInstanceOf[T]
+          x.asInstanceOf[T]
 
         case None =>
           logInfo("Started reading broadcast variable " + id)
-          val start = System.nanoTime()
+          val startTimeMs = System.currentTimeMillis()
           val blocks = readBlocks()
-          val time = (System.nanoTime() - start) / 1e9
-          logInfo("Reading broadcast variable " + id + " took " + time + " s")
+          logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs))
 
-          _value =
-            TorrentBroadcast.unBlockifyObject[T](blocks, SparkEnv.get.serializer, compressionCodec)
+          val obj = TorrentBroadcast.unBlockifyObject[T](
+            blocks, SparkEnv.get.serializer, compressionCodec)
           // Store the merged copy in BlockManager so other tasks on this executor don't
           // need to re-fetch it.
           SparkEnv.get.blockManager.putSingle(
-            broadcastId, _value, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
+            broadcastId, obj, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
+          obj
       }
     }
   }
+
 }
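
For readers skimming the hunk above, a simplified standalone sketch of the cache-or-fetch pattern that readBroadcastBlock now follows; `localCache` and `fetchAndAssemble` are hypothetical stand-ins for the BlockManager calls and readBlocks():

    import scala.collection.mutable

    object CacheOrFetchSketch {
      // Hypothetical stand-ins for BlockManager.getLocal/putSingle and readBlocks().
      private val localCache = mutable.Map.empty[Long, String]
      private def fetchAndAssemble(id: Long): String = "blocks-for-" + id

      def readBroadcastBlock(id: Long): String = synchronized {
        localCache.get(id) match {
          case Some(cached) =>
            cached                          // already materialized on this JVM
          case None =>
            val obj = fetchAndAssemble(id)  // fetch pieces and un-blockify
            localCache(id) = obj            // store so later tasks skip the fetch
            obj
        }
      }
    }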
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/7768a800/core/src/main/scala/org/apache/spark/util/Utils.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 4660030..612eca3 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -988,6 +988,21 @@ private[spark] object Utils extends Logging {
     }
   }
 
+  /**
+   * Execute a block of code that returns a value, re-throwing any non-fatal uncaught
+   * exceptions as IOException. This is used when implementing Externalizable and Serializable's
+   * read and write methods, since Java's serializer will not report non-IOExceptions properly;
+   * see SPARK-4080 for more context.
+   */
+  def tryOrIOException[T](block: => T): T = {
+    try {
+      block
+    } catch {
+      case e: IOException => throw e
+      case NonFatal(t) => throw new IOException(t)
+    }
+  }
+
   /** Default filtering function for finding call sites using `getCallSite`. */
   private def coreExclusionFunction(className: String): Boolean = {
    // A regular expression to match classes of the "core" Spark API that we want to skip when
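
To show how a custom readObject would typically use this helper, a minimal standalone sketch (it reproduces the helper so the snippet compiles on its own):

    import java.io.{IOException, ObjectInputStream}
    import scala.util.control.NonFatal

    object TryOrIOExceptionSketch {
      // Same shape as the Utils.tryOrIOException added above.
      def tryOrIOException[T](block: => T): T = {
        try {
          block
        } catch {
          case e: IOException => throw e
          case NonFatal(t) => throw new IOException(t)
        }
      }

      class Example extends Serializable {
        // Java serialization only propagates IOExceptions reliably, so wrap the
        // body and rethrow anything non-fatal as an IOException.
        private def readObject(in: ObjectInputStream): Unit = tryOrIOException {
          in.defaultReadObject()
        }
      }
    }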

http://git-wip-us.apache.org/repos/asf/spark/blob/7768a800/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
index 1014fd6..b0a70f0 100644
--- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
@@ -21,11 +21,28 @@ import scala.util.Random
 
 import org.scalatest.{Assertions, FunSuite}
 
-import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException}
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkEnv}
 import org.apache.spark.io.SnappyCompressionCodec
+import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.storage._
 
+// Dummy class that creates a broadcast variable but doesn't use it
+class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable {
+  @transient val list = List(1, 2, 3, 4)
+  val broadcast = rdd.context.broadcast(list)
+  val bid = broadcast.id
+
+  def doSomething() = {
+    rdd.map { x =>
+      val bm = SparkEnv.get.blockManager
+      // Check if broadcast block was fetched
+      val isFound = bm.getLocal(BroadcastBlockId(bid)).isDefined
+      (x, isFound)
+    }.collect().toSet
+  }
+}
+
 class BroadcastSuite extends FunSuite with LocalSparkContext {
 
   private val httpConf = broadcastConf("HttpBroadcastFactory")
@@ -105,6 +122,17 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
     }
   }
 
+  test("Test Lazy Broadcast variables with TorrentBroadcast") {
+    val numSlaves = 2
+    val conf = torrentConf.clone
+    sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", conf)
+    val rdd = sc.parallelize(1 to numSlaves)
+
+    val results = new DummyBroadcastClass(rdd).doSomething()
+
+    assert(results.toSet === (1 to numSlaves).map(x => (x, false)).toSet)
+  }
+
   test("Unpersisting HttpBroadcast on executors only in local mode") {
     testUnpersistHttpBroadcast(distributed = false, removeFromDriver = false)
   }
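
For context, a small driver program sketching the user-facing behavior the new test pins down; it uses only the standard broadcast API and assumes a local master:

    import org.apache.spark.{SparkConf, SparkContext}

    object LazyBroadcastDemo {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf().setMaster("local[2]").setAppName("lazy-broadcast-demo")
        val sc = new SparkContext(conf)
        val lookup = sc.broadcast(Map(1 -> "one", 2 -> "two"))
        // The closure captures `lookup`, so the Broadcast handle ships with the
        // task, but after this change the blocks are fetched only when .value runs.
        val out = sc.parallelize(1 to 2).map(i => lookup.value(i)).collect()
        println(out.mkString(", "))
        sc.stop()
      }
    }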

