spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From pwend...@apache.org
Subject [2/3] [SPARK-1103] Automatic garbage collection of RDD, shuffle and broadcast data
Date Tue, 08 Apr 2014 06:41:00 GMT
http://git-wip-us.apache.org/repos/asf/spark/blob/11eabbe1/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 19138d9..b021564 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -19,20 +19,22 @@ package org.apache.spark.storage
 
 import java.io.{File, InputStream, OutputStream}
 import java.nio.{ByteBuffer, MappedByteBuffer}
+
 import scala.collection.mutable.{ArrayBuffer, HashMap}
 import scala.concurrent.{Await, Future}
 import scala.concurrent.duration._
 import scala.util.Random
+
 import akka.actor.{ActorSystem, Cancellable, Props}
 import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream}
 import sun.nio.ch.DirectBuffer
-import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv, SparkException}
+
+import org.apache.spark.{Logging, MapOutputTracker, SecurityManager, SparkConf, SparkEnv, SparkException}
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.network._
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.util._
 
-
 sealed trait Values
 
 case class ByteBufferValues(buffer: ByteBuffer) extends Values
@@ -46,7 +48,8 @@ private[spark] class BlockManager(
     val defaultSerializer: Serializer,
     maxMemory: Long,
     val conf: SparkConf,
-    securityManager: SecurityManager)
+    securityManager: SecurityManager,
+    mapOutputTracker: MapOutputTracker)
   extends Logging {
 
   val shuffleBlockManager = new ShuffleBlockManager(this)
@@ -55,7 +58,7 @@ private[spark] class BlockManager(
 
   private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]
 
-  private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
+  private[storage] val memoryStore = new MemoryStore(this, maxMemory)
   private[storage] val diskStore = new DiskStore(this, diskBlockManager)
   var tachyonInitialized = false
   private[storage] lazy val tachyonStore: TachyonStore = {
@@ -98,7 +101,7 @@ private[spark] class BlockManager(
 
   val heartBeatFrequency = BlockManager.getHeartBeatFrequency(conf)
 
-  val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)),
+  val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this, mapOutputTracker)),
     name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next)
 
   // Pending re-registration action being executed asynchronously or null if none
@@ -137,9 +140,10 @@ private[spark] class BlockManager(
       master: BlockManagerMaster,
       serializer: Serializer,
       conf: SparkConf,
-      securityManager: SecurityManager) = {
+      securityManager: SecurityManager,
+      mapOutputTracker: MapOutputTracker) = {
     this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf),
-      conf, securityManager)
+      conf, securityManager, mapOutputTracker)
   }
 
   /**
@@ -217,9 +221,26 @@ private[spark] class BlockManager(
   }
 
   /**
-   * Get storage level of local block. If no info exists for the block, then returns null.
+   * Get the BlockStatus for the block identified by the given ID, if it exists.
+   * NOTE: This is mainly for testing, and it doesn't fetch information from Tachyon.
+   */
+  def getStatus(blockId: BlockId): Option[BlockStatus] = {
+    blockInfo.get(blockId).map { info =>
+      val memSize = if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L
+      val diskSize = if (diskStore.contains(blockId)) diskStore.getSize(blockId) else 0L
+      // Assume that block is not in Tachyon
+      BlockStatus(info.level, memSize, diskSize, 0L)
+    }
+  }
+
+  /**
+   * Get the ids of existing blocks that match the given filter. Note that this will
+   * query the blocks stored in the disk block manager (that the block manager
+   * may not know of).
    */
-  def getLevel(blockId: BlockId): StorageLevel = blockInfo.get(blockId).map(_.level).orNull
+  def getMatchingBlockIds(filter: BlockId => Boolean): Seq[BlockId] = {
+    (blockInfo.keys ++ diskBlockManager.getAllBlocks()).filter(filter).toSeq
+  }
 
   /**
    * Tell the master about the current storage status of a block. This will send a block update
@@ -525,9 +546,8 @@ private[spark] class BlockManager(
 
   /**
    * A short circuited method to get a block writer that can write data directly to disk.
-   * The Block will be appended to the File specified by filename.
-   * This is currently used for writing shuffle files out. Callers should handle error
-   * cases.
+   * The Block will be appended to the File specified by filename. This is currently used for
+   * writing shuffle files out. Callers should handle error cases.
    */
   def getDiskWriter(
       blockId: BlockId,
@@ -863,11 +883,22 @@ private[spark] class BlockManager(
    * @return The number of blocks removed.
    */
   def removeRdd(rddId: Int): Int = {
-    // TODO: Instead of doing a linear scan on the blockInfo map, create another map that maps
-    // from RDD.id to blocks.
+    // TODO: Avoid a linear scan by creating another mapping of RDD.id to blocks.
     logInfo("Removing RDD " + rddId)
     val blocksToRemove = blockInfo.keys.flatMap(_.asRDDId).filter(_.rddId == rddId)
-    blocksToRemove.foreach(blockId => removeBlock(blockId, tellMaster = false))
+    blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster = false) }
+    blocksToRemove.size
+  }
+
+  /**
+   * Remove all blocks belonging to the given broadcast.
+   */
+  def removeBroadcast(broadcastId: Long, tellMaster: Boolean): Int = {
+    logInfo("Removing broadcast " + broadcastId)
+    val blocksToRemove = blockInfo.keys.collect {
+      case bid @ BroadcastBlockId(`broadcastId`, _) => bid
+    }
+    blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster) }
     blocksToRemove.size
   }
 
@@ -908,10 +939,10 @@ private[spark] class BlockManager(
   }
 
   private def dropOldBlocks(cleanupTime: Long, shouldDrop: (BlockId => Boolean)) {
-    val iterator = blockInfo.internalMap.entrySet().iterator()
+    val iterator = blockInfo.getEntrySet.iterator
     while (iterator.hasNext) {
       val entry = iterator.next()
-      val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2)
+      val (id, info, time) = (entry.getKey, entry.getValue.value, entry.getValue.timestamp)
       if (time < cleanupTime && shouldDrop(id)) {
         info.synchronized {
           val level = info.level
@@ -935,7 +966,7 @@ private[spark] class BlockManager(
 
   def shouldCompress(blockId: BlockId): Boolean = blockId match {
     case ShuffleBlockId(_, _, _) => compressShuffle
-    case BroadcastBlockId(_) => compressBroadcast
+    case BroadcastBlockId(_, _) => compressBroadcast
     case RDDBlockId(_, _) => compressRdds
     case TempBlockId(_) => compressShuffleSpill
     case _ => false

http://git-wip-us.apache.org/repos/asf/spark/blob/11eabbe1/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index 4bc1b40..7897fad 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -81,6 +81,14 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log
     askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds))
   }
 
+  /**
+   * Check if block manager master has a block. Note that this can be used to check for only
+   * those blocks that are reported to block manager master.
+   */
+  def contains(blockId: BlockId) = {
+    !getLocations(blockId).isEmpty
+  }
+
   /** Get ids of other nodes in the cluster from the driver */
   def getPeers(blockManagerId: BlockManagerId, numPeers: Int): Seq[BlockManagerId] = {
     val result = askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers))
@@ -99,12 +107,10 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log
     askDriverWithReply(RemoveBlock(blockId))
   }
 
-  /**
-   * Remove all blocks belonging to the given RDD.
-   */
+  /** Remove all blocks belonging to the given RDD. */
   def removeRdd(rddId: Int, blocking: Boolean) {
     val future = askDriverWithReply[Future[Seq[Int]]](RemoveRdd(rddId))
-    future onFailure {
+    future.onFailure {
       case e: Throwable => logError("Failed to remove RDD " + rddId, e)
     }
     if (blocking) {
@@ -112,6 +118,31 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log
     }
   }
 
+  /** Remove all blocks belonging to the given shuffle. */
+  def removeShuffle(shuffleId: Int, blocking: Boolean) {
+    val future = askDriverWithReply[Future[Seq[Boolean]]](RemoveShuffle(shuffleId))
+    future.onFailure {
+      case e: Throwable => logError("Failed to remove shuffle " + shuffleId, e)
+    }
+    if (blocking) {
+      Await.result(future, timeout)
+    }
+  }
+
+  /** Remove all blocks belonging to the given broadcast. */
+  def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean, blocking: Boolean) {
+    val future = askDriverWithReply[Future[Seq[Int]]](
+      RemoveBroadcast(broadcastId, removeFromMaster))
+    future.onFailure {
+      case e: Throwable =>
+        logError("Failed to remove broadcast " + broadcastId +
+          " with removeFromMaster = " + removeFromMaster, e)
+    }
+    if (blocking) {
+      Await.result(future, timeout)
+    }
+  }
+
   /**
    * Return the memory status for each block manager, in the form of a map from
    * the block manager's id to two long values. The first value is the maximum
@@ -126,6 +157,51 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log
     askDriverWithReply[Array[StorageStatus]](GetStorageStatus)
   }
 
+  /**
+   * Return the block's status on all block managers, if any. NOTE: This is a
+   * potentially expensive operation and should only be used for testing.
+   *
+   * If askSlaves is true, this invokes the master to query each block manager for the most
+   * updated block statuses. This is useful when the master is not informed of the given block
+   * by all block managers.
+   */
+  def getBlockStatus(
+      blockId: BlockId,
+      askSlaves: Boolean = true): Map[BlockManagerId, BlockStatus] = {
+    val msg = GetBlockStatus(blockId, askSlaves)
+    /*
+     * To avoid potential deadlocks, the use of Futures is necessary, because the master actor
+     * should not block on waiting for a block manager, which can in turn be waiting for the
+     * master actor for a response to a prior message.
+     */
+    val response = askDriverWithReply[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg)
+    val (blockManagerIds, futures) = response.unzip
+    val result = Await.result(Future.sequence(futures), timeout)
+    if (result == null) {
+      throw new SparkException("BlockManager returned null for BlockStatus query: " + blockId)
+    }
+    val blockStatus = result.asInstanceOf[Iterable[Option[BlockStatus]]]
+    blockManagerIds.zip(blockStatus).flatMap { case (blockManagerId, status) =>
+      status.map { s => (blockManagerId, s) }
+    }.toMap
+  }
+
+  /**
+   * Return a list of ids of existing blocks such that the ids match the given filter. NOTE: This
+   * is a potentially expensive operation and should only be used for testing.
+   *
+   * If askSlaves is true, this invokes the master to query each block manager for the most
+   * updated block statuses. This is useful when the master is not informed of the given block
+   * by all block managers.
+   */
+  def getMatchingBlockIds(
+      filter: BlockId => Boolean,
+      askSlaves: Boolean): Seq[BlockId] = {
+    val msg = GetMatchingBlockIds(filter, askSlaves)
+    val future = askDriverWithReply[Future[Seq[BlockId]]](msg)
+    Await.result(future, timeout)
+  }
+
   /** Stop the driver actor, called only on the Spark driver node */
   def stop() {
     if (driverActor != null) {

http://git-wip-us.apache.org/repos/asf/spark/blob/11eabbe1/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
index 378f4ca..c57b6e8 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
@@ -94,9 +94,21 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
     case GetStorageStatus =>
       sender ! storageStatus
 
+    case GetBlockStatus(blockId, askSlaves) =>
+      sender ! blockStatus(blockId, askSlaves)
+
+    case GetMatchingBlockIds(filter, askSlaves) =>
+      sender ! getMatchingBlockIds(filter, askSlaves)
+
     case RemoveRdd(rddId) =>
       sender ! removeRdd(rddId)
 
+    case RemoveShuffle(shuffleId) =>
+      sender ! removeShuffle(shuffleId)
+
+    case RemoveBroadcast(broadcastId, removeFromDriver) =>
+      sender ! removeBroadcast(broadcastId, removeFromDriver)
+
     case RemoveBlock(blockId) =>
       removeBlockFromWorkers(blockId)
       sender ! true
@@ -140,9 +152,41 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
     // The dispatcher is used as an implicit argument into the Future sequence construction.
     import context.dispatcher
     val removeMsg = RemoveRdd(rddId)
-    Future.sequence(blockManagerInfo.values.map { bm =>
-      bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int]
-    }.toSeq)
+    Future.sequence(
+      blockManagerInfo.values.map { bm =>
+        bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int]
+      }.toSeq
+    )
+  }
+
+  private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = {
+    // Nothing to do in the BlockManagerMasterActor data structures
+    import context.dispatcher
+    val removeMsg = RemoveShuffle(shuffleId)
+    Future.sequence(
+      blockManagerInfo.values.map { bm =>
+        bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Boolean]
+      }.toSeq
+    )
+  }
+
+  /**
+   * Delegate RemoveBroadcast messages to each BlockManager because the master may not be
+   * notified of all broadcast blocks. If removeFromDriver is false, broadcast blocks are only
+   * removed from the executors, but not from the driver.
+   */
+  private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean): Future[Seq[Int]] = {
+    // TODO: Consolidate usages of <driver>
+    import context.dispatcher
+    val removeMsg = RemoveBroadcast(broadcastId, removeFromDriver)
+    val requiredBlockManagers = blockManagerInfo.values.filter { info =>
+      removeFromDriver || info.blockManagerId.executorId != "<driver>"
+    }
+    Future.sequence(
+      requiredBlockManagers.map { bm =>
+        bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int]
+      }.toSeq
+    )
   }
 
   private def removeBlockManager(blockManagerId: BlockManagerId) {
@@ -225,6 +269,61 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
     }.toArray
   }
 
+  /**
+   * Return the block's status for all block managers, if any. NOTE: This is a
+   * potentially expensive operation and should only be used for testing.
+   *
+   * If askSlaves is true, the master queries each block manager for the most updated block
+   * statuses. This is useful when the master is not informed of the given block by all block
+   * managers.
+   */
+  private def blockStatus(
+      blockId: BlockId,
+      askSlaves: Boolean): Map[BlockManagerId, Future[Option[BlockStatus]]] = {
+    import context.dispatcher
+    val getBlockStatus = GetBlockStatus(blockId)
+    /*
+     * Rather than blocking on the block status query, master actor should simply return
+     * Futures to avoid potential deadlocks. This can arise if there exists a block manager
+     * that is also waiting for this master actor's response to a previous message.
+     */
+    blockManagerInfo.values.map { info =>
+      val blockStatusFuture =
+        if (askSlaves) {
+          info.slaveActor.ask(getBlockStatus)(akkaTimeout).mapTo[Option[BlockStatus]]
+        } else {
+          Future { info.getStatus(blockId) }
+        }
+      (info.blockManagerId, blockStatusFuture)
+    }.toMap
+  }
+
+  /**
+   * Return the ids of blocks present in all the block managers that match the given filter.
+   * NOTE: This is a potentially expensive operation and should only be used for testing.
+   *
+   * If askSlaves is true, the master queries each block manager for the most updated block
+   * statuses. This is useful when the master is not informed of the given block by all block
+   * managers.
+   */
+  private def getMatchingBlockIds(
+      filter: BlockId => Boolean,
+      askSlaves: Boolean): Future[Seq[BlockId]] = {
+    import context.dispatcher
+    val getMatchingBlockIds = GetMatchingBlockIds(filter)
+    Future.sequence(
+      blockManagerInfo.values.map { info =>
+        val future =
+          if (askSlaves) {
+            info.slaveActor.ask(getMatchingBlockIds)(akkaTimeout).mapTo[Seq[BlockId]]
+          } else {
+            Future { info.blocks.keys.filter(filter).toSeq }
+          }
+        future
+      }
+    ).map(_.flatten.toSeq)
+  }
+
   private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
     if (!blockManagerInfo.contains(id)) {
       blockManagerIdByExecutor.get(id.executorId) match {
@@ -334,6 +433,8 @@ private[spark] class BlockManagerInfo(
   logInfo("Registering block manager %s with %s RAM".format(
     blockManagerId.hostPort, Utils.bytesToString(maxMem)))
 
+  def getStatus(blockId: BlockId) = Option(_blocks.get(blockId))
+
   def updateLastSeenMs() {
     _lastSeenMs = System.currentTimeMillis()
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/11eabbe1/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
index 8a36b5c..2b53bf3 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
@@ -34,6 +34,13 @@ private[storage] object BlockManagerMessages {
   // Remove all blocks belonging to a specific RDD.
   case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave
 
+  // Remove all blocks belonging to a specific shuffle.
+  case class RemoveShuffle(shuffleId: Int) extends ToBlockManagerSlave
+
+  // Remove all blocks belonging to a specific broadcast.
+  case class RemoveBroadcast(broadcastId: Long, removeFromDriver: Boolean = true)
+    extends ToBlockManagerSlave
+
 
   //////////////////////////////////////////////////////////////////////////////////
   // Messages from slaves to the master.
@@ -80,7 +87,8 @@ private[storage] object BlockManagerMessages {
   }
 
   object UpdateBlockInfo {
-    def apply(blockManagerId: BlockManagerId,
+    def apply(
+        blockManagerId: BlockManagerId,
         blockId: BlockId,
         storageLevel: StorageLevel,
         memSize: Long,
@@ -108,7 +116,13 @@ private[storage] object BlockManagerMessages {
 
   case object GetMemoryStatus extends ToBlockManagerMaster
 
-  case object ExpireDeadHosts extends ToBlockManagerMaster
-
   case object GetStorageStatus extends ToBlockManagerMaster
+
+  case class GetBlockStatus(blockId: BlockId, askSlaves: Boolean = true)
+    extends ToBlockManagerMaster
+
+  case class GetMatchingBlockIds(filter: BlockId => Boolean, askSlaves: Boolean = true)
+    extends ToBlockManagerMaster
+
+  case object ExpireDeadHosts extends ToBlockManagerMaster
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/11eabbe1/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
index bcfb82d..6d4db06 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
@@ -17,8 +17,11 @@
 
 package org.apache.spark.storage
 
-import akka.actor.Actor
+import scala.concurrent.Future
 
+import akka.actor.{ActorRef, Actor}
+
+import org.apache.spark.{Logging, MapOutputTracker}
 import org.apache.spark.storage.BlockManagerMessages._
 
 /**
@@ -26,14 +29,59 @@ import org.apache.spark.storage.BlockManagerMessages._
  * this is used to remove blocks from the slave's BlockManager.
  */
 private[storage]
-class BlockManagerSlaveActor(blockManager: BlockManager) extends Actor {
-  override def receive = {
+class BlockManagerSlaveActor(
+    blockManager: BlockManager,
+    mapOutputTracker: MapOutputTracker)
+  extends Actor with Logging {
+
+  import context.dispatcher
 
+  // Operations that involve removing blocks may be slow and should be done asynchronously
+  override def receive = {
     case RemoveBlock(blockId) =>
-      blockManager.removeBlock(blockId)
+      doAsync[Boolean]("removing block " + blockId, sender) {
+        blockManager.removeBlock(blockId)
+        true
+      }
 
     case RemoveRdd(rddId) =>
-      val numBlocksRemoved = blockManager.removeRdd(rddId)
-      sender ! numBlocksRemoved
+      doAsync[Int]("removing RDD " + rddId, sender) {
+        blockManager.removeRdd(rddId)
+      }
+
+    case RemoveShuffle(shuffleId) =>
+      doAsync[Boolean]("removing shuffle " + shuffleId, sender) {
+        if (mapOutputTracker != null) {
+          mapOutputTracker.unregisterShuffle(shuffleId)
+        }
+        blockManager.shuffleBlockManager.removeShuffle(shuffleId)
+      }
+
+    case RemoveBroadcast(broadcastId, tellMaster) =>
+      doAsync[Int]("removing broadcast " + broadcastId, sender) {
+        blockManager.removeBroadcast(broadcastId, tellMaster)
+      }
+
+    case GetBlockStatus(blockId, _) =>
+      sender ! blockManager.getStatus(blockId)
+
+    case GetMatchingBlockIds(filter, _) =>
+      sender ! blockManager.getMatchingBlockIds(filter)
+  }
+
+  private def doAsync[T](actionMessage: String, responseActor: ActorRef)(body: => T) {
+    val future = Future {
+      logDebug(actionMessage)
+      body
+    }
+    future.onSuccess { case response =>
+      logDebug("Done " + actionMessage + ", response is " + response)
+      responseActor ! response
+      logDebug("Sent response: " + response + " to " + responseActor)
+    }
+    future.onFailure { case t: Throwable =>
+      logError("Error in " + actionMessage, t)
+      responseActor ! null.asInstanceOf[T]
+    }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/11eabbe1/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index f3e1c38..7a24c8f 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -90,6 +90,20 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD
 
   def getFile(blockId: BlockId): File = getFile(blockId.name)
 
+  /** Check if disk block manager has a block. */
+  def containsBlock(blockId: BlockId): Boolean = {
+    getBlockLocation(blockId).file.exists()
+  }
+
+  /** List all the blocks currently stored on disk by the disk manager. */
+  def getAllBlocks(): Seq[BlockId] = {
+    // Get all the files inside the array of array of directories
+    subDirs.flatten.filter(_ != null).flatMap { dir =>
+      val files = dir.list()
+      if (files != null) files else Seq.empty
+    }.map(BlockId.apply)
+  }
+
   /** Produces a unique block id and File suitable for intermediate results. */
   def createTempBlock(): (TempBlockId, File) = {
     var blockId = new TempBlockId(UUID.randomUUID())

http://git-wip-us.apache.org/repos/asf/spark/blob/11eabbe1/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
index bb07c8c..4cd4cdb 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
@@ -169,23 +169,43 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
     throw new IllegalStateException("Failed to find shuffle block: " + id)
   }
 
+  /** Remove all the blocks / files and metadata related to a particular shuffle. */
+  def removeShuffle(shuffleId: ShuffleId): Boolean = {
+    // Do not change the ordering of this: shuffleStates should be removed only
+    // after the corresponding shuffle blocks have been removed
+    val cleaned = removeShuffleBlocks(shuffleId)
+    shuffleStates.remove(shuffleId)
+    cleaned
+  }
+
+  /** Remove all the blocks / files related to a particular shuffle. */
+  private def removeShuffleBlocks(shuffleId: ShuffleId): Boolean = {
+    shuffleStates.get(shuffleId) match {
+      case Some(state) =>
+        if (consolidateShuffleFiles) {
+          for (fileGroup <- state.allFileGroups; file <- fileGroup.files) {
+            file.delete()
+          }
+        } else {
+          for (mapId <- state.completedMapTasks; reduceId <- 0 until state.numBuckets) {
+            val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId)
+            blockManager.diskBlockManager.getFile(blockId).delete()
+          }
+        }
+        logInfo("Deleted all files for shuffle " + shuffleId)
+        true
+      case None =>
+        logInfo("Could not find files for shuffle " + shuffleId + " for deleting")
+        false
+    }
+  }
+
   private def physicalFileName(shuffleId: Int, bucketId: Int, fileId: Int) = {
     "merged_shuffle_%d_%d_%d".format(shuffleId, bucketId, fileId)
   }
 
   private def cleanup(cleanupTime: Long) {
-    shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => {
-      if (consolidateShuffleFiles) {
-        for (fileGroup <- state.allFileGroups; file <- fileGroup.files) {
-          file.delete()
-        }
-      } else {
-        for (mapId <- state.completedMapTasks; reduceId <- 0 until state.numBuckets) {
-          val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId)
-          blockManager.diskBlockManager.getFile(blockId).delete()
-        }
-      }
-    })
+    shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => removeShuffleBlocks(shuffleId))
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/11eabbe1/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
index 226ed2a..a107c51 100644
--- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
@@ -22,7 +22,7 @@ import java.util.concurrent.ArrayBlockingQueue
 import akka.actor._
 import util.Random
 
-import org.apache.spark.{SecurityManager, SparkConf}
+import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf}
 import org.apache.spark.scheduler.LiveListenerBus
 import org.apache.spark.serializer.KryoSerializer
 
@@ -48,7 +48,7 @@ private[spark] object ThreadingTest {
         val block = (1 to blockSize).map(_ => Random.nextInt())
         val level = randomLevel()
         val startTime = System.currentTimeMillis()
-        manager.put(blockId, block.iterator, level, true)
+        manager.put(blockId, block.iterator, level, tellMaster = true)
         println("Pushed block " + blockId + " in " + (System.currentTimeMillis - startTime) + " ms")
         queue.add((blockId, block))
       }
@@ -101,7 +101,7 @@ private[spark] object ThreadingTest {
       conf)
     val blockManager = new BlockManager(
       "<driver>", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf,
-      new SecurityManager(conf))
+      new SecurityManager(conf), new MapOutputTrackerMaster(conf))
     val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i))
     val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue))
     producers.foreach(_.start)

http://git-wip-us.apache.org/repos/asf/spark/blob/11eabbe1/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
index 0448919..7ebed51 100644
--- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
@@ -62,8 +62,8 @@ private[spark] class MetadataCleaner(
 
 private[spark] object MetadataCleanerType extends Enumeration {
 
-  val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK,
-    SHUFFLE_MAP_TASK, BLOCK_MANAGER, SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS = Value
+  val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, BLOCK_MANAGER,
+  SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS = Value
 
   type MetadataCleanerType = Value
 
@@ -78,15 +78,16 @@ private[spark] object MetadataCleaner {
     conf.getInt("spark.cleaner.ttl", -1)
   }
 
-  def getDelaySeconds(conf: SparkConf, cleanerType: MetadataCleanerType.MetadataCleanerType): Int =
-  {
-    conf.get(MetadataCleanerType.systemProperty(cleanerType), getDelaySeconds(conf).toString)
-      .toInt
+  def getDelaySeconds(
+      conf: SparkConf,
+      cleanerType: MetadataCleanerType.MetadataCleanerType): Int = {
+    conf.get(MetadataCleanerType.systemProperty(cleanerType), getDelaySeconds(conf).toString).toInt
   }
 
-  def setDelaySeconds(conf: SparkConf, cleanerType: MetadataCleanerType.MetadataCleanerType,
-      delay: Int)
-  {
+  def setDelaySeconds(
+      conf: SparkConf,
+      cleanerType: MetadataCleanerType.MetadataCleanerType,
+      delay: Int) {
     conf.set(MetadataCleanerType.systemProperty(cleanerType),  delay.toString)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/11eabbe1/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
index ddbd084..8de75ba 100644
--- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
@@ -17,48 +17,54 @@
 
 package org.apache.spark.util
 
+import java.util.Set
+import java.util.Map.Entry
 import java.util.concurrent.ConcurrentHashMap
 
-import scala.collection.JavaConversions
-import scala.collection.immutable
-import scala.collection.mutable.Map
+import scala.collection.{JavaConversions, mutable}
 
 import org.apache.spark.Logging
 
+private[spark] case class TimeStampedValue[V](value: V, timestamp: Long)
+
 /**
  * This is a custom implementation of scala.collection.mutable.Map which stores the insertion
  * timestamp along with each key-value pair. If specified, the timestamp of each pair can be
  * updated every time it is accessed. Key-value pairs whose timestamp are older than a particular
  * threshold time can then be removed using the clearOldValues method. This is intended to
  * be a drop-in replacement of scala.collection.mutable.HashMap.
- * @param updateTimeStampOnGet When enabled, the timestamp of a pair will be
- *                             updated when it is accessed
+ *
+ * @param updateTimeStampOnGet Whether timestamp of a pair will be updated when it is accessed
  */
-class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false)
-  extends Map[A, B]() with Logging {
-  val internalMap = new ConcurrentHashMap[A, (B, Long)]()
+private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false)
+  extends mutable.Map[A, B]() with Logging {
+
+  private val internalMap = new ConcurrentHashMap[A, TimeStampedValue[B]]()
 
   def get(key: A): Option[B] = {
     val value = internalMap.get(key)
     if (value != null && updateTimeStampOnGet) {
-      internalMap.replace(key, value, (value._1, currentTime))
+      internalMap.replace(key, value, TimeStampedValue(value.value, currentTime))
     }
-    Option(value).map(_._1)
+    Option(value).map(_.value)
   }
 
   def iterator: Iterator[(A, B)] = {
-    val jIterator = internalMap.entrySet().iterator()
-    JavaConversions.asScalaIterator(jIterator).map(kv => (kv.getKey, kv.getValue._1))
+    val jIterator = getEntrySet.iterator
+    JavaConversions.asScalaIterator(jIterator).map(kv => (kv.getKey, kv.getValue.value))
   }
 
-  override def + [B1 >: B](kv: (A, B1)): Map[A, B1] = {
+  def getEntrySet: Set[Entry[A, TimeStampedValue[B]]] = internalMap.entrySet
+
+  override def + [B1 >: B](kv: (A, B1)): mutable.Map[A, B1] = {
     val newMap = new TimeStampedHashMap[A, B1]
-    newMap.internalMap.putAll(this.internalMap)
-    newMap.internalMap.put(kv._1, (kv._2, currentTime))
+    val oldInternalMap = this.internalMap.asInstanceOf[ConcurrentHashMap[A, TimeStampedValue[B1]]]
+    newMap.internalMap.putAll(oldInternalMap)
+    kv match { case (a, b) => newMap.internalMap.put(a, TimeStampedValue(b, currentTime)) }
     newMap
   }
 
-  override def - (key: A): Map[A, B] = {
+  override def - (key: A): mutable.Map[A, B] = {
     val newMap = new TimeStampedHashMap[A, B]
     newMap.internalMap.putAll(this.internalMap)
     newMap.internalMap.remove(key)
@@ -66,17 +72,10 @@ class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false)
   }
 
   override def += (kv: (A, B)): this.type = {
-    internalMap.put(kv._1, (kv._2, currentTime))
+    kv match { case (a, b) => internalMap.put(a, TimeStampedValue(b, currentTime)) }
     this
   }
 
-  // Should we return previous value directly or as Option ?
-  def putIfAbsent(key: A, value: B): Option[B] = {
-    val prev = internalMap.putIfAbsent(key, (value, currentTime))
-    if (prev != null) Some(prev._1) else None
-  }
-
-
   override def -= (key: A): this.type = {
     internalMap.remove(key)
     this
@@ -87,53 +86,65 @@ class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false)
   }
 
   override def apply(key: A): B = {
-    val value = internalMap.get(key)
-    if (value == null) throw new NoSuchElementException()
-    value._1
+    get(key).getOrElse { throw new NoSuchElementException() }
   }
 
-  override def filter(p: ((A, B)) => Boolean): Map[A, B] = {
-    JavaConversions.mapAsScalaConcurrentMap(internalMap).map(kv => (kv._1, kv._2._1)).filter(p)
+  override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = {
+    JavaConversions.mapAsScalaConcurrentMap(internalMap)
+      .map { case (k, TimeStampedValue(v, t)) => (k, v) }
+      .filter(p)
   }
 
-  override def empty: Map[A, B] = new TimeStampedHashMap[A, B]()
+  override def empty: mutable.Map[A, B] = new TimeStampedHashMap[A, B]()
 
   override def size: Int = internalMap.size
 
   override def foreach[U](f: ((A, B)) => U) {
-    val iterator = internalMap.entrySet().iterator()
-    while(iterator.hasNext) {
-      val entry = iterator.next()
-      val kv = (entry.getKey, entry.getValue._1)
+    val it = getEntrySet.iterator
+    while(it.hasNext) {
+      val entry = it.next()
+      val kv = (entry.getKey, entry.getValue.value)
       f(kv)
     }
   }
 
-  def toMap: immutable.Map[A, B] = iterator.toMap
+  def putIfAbsent(key: A, value: B): Option[B] = {
+    val prev = internalMap.putIfAbsent(key, TimeStampedValue(value, currentTime))
+    Option(prev).map(_.value)
+  }
+
+  def putAll(map: Map[A, B]) {
+    map.foreach { case (k, v) => update(k, v) }
+  }
+
+  def toMap: Map[A, B] = iterator.toMap
 
-  /**
-   * Removes old key-value pairs that have timestamp earlier than `threshTime`,
-   * calling the supplied function on each such entry before removing.
-   */
   def clearOldValues(threshTime: Long, f: (A, B) => Unit) {
-    val iterator = internalMap.entrySet().iterator()
-    while (iterator.hasNext) {
-      val entry = iterator.next()
-      if (entry.getValue._2 < threshTime) {
-        f(entry.getKey, entry.getValue._1)
+    val it = getEntrySet.iterator
+    while (it.hasNext) {
+      val entry = it.next()
+      if (entry.getValue.timestamp < threshTime) {
+        f(entry.getKey, entry.getValue.value)
         logDebug("Removing key " + entry.getKey)
-        iterator.remove()
+        it.remove()
       }
     }
   }
 
-  /**
-   * Removes old key-value pairs that have timestamp earlier than `threshTime`
-   */
+  /** Removes old key-value pairs that have timestamp earlier than `threshTime`. */
   def clearOldValues(threshTime: Long) {
     clearOldValues(threshTime, (_, _) => ())
   }
 
-  private def currentTime: Long = System.currentTimeMillis()
+  private def currentTime: Long = System.currentTimeMillis
 
+  // For testing
+
+  def getTimeStampedValue(key: A): Option[TimeStampedValue[B]] = {
+    Option(internalMap.get(key))
+  }
+
+  def getTimestamp(key: A): Option[Long] = {
+    getTimeStampedValue(key).map(_.timestamp)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/11eabbe1/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala
new file mode 100644
index 0000000..b65017d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.lang.ref.WeakReference
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.collection.mutable
+
+import org.apache.spark.Logging
+
+/**
+ * A wrapper of TimeStampedHashMap that ensures the values are weakly referenced and timestamped.
+ *
+ * If a value has been garbage collected, so that its weak reference is null, get() returns
+ * None for that key. Such entries are removed from the map periodically (every N inserts), as
+ * their values are no longer strongly reachable. Further, key-value pairs whose timestamps are
+ * older than a particular threshold can be removed using the clearOldValues method.
+ *
+ * TimeStampedWeakValueHashMap exposes a scala.collection.mutable.Map interface, which allows it
+ * to be a drop-in replacement for Scala HashMaps. Internally, it uses a Java ConcurrentHashMap,
+ * so all operations on this HashMap are thread-safe.
+ *
+ * @param updateTimeStampOnGet Whether timestamp of a pair will be updated when it is accessed.
+ */
+private[spark] class TimeStampedWeakValueHashMap[A, B](updateTimeStampOnGet: Boolean = false)
+  extends mutable.Map[A, B]() with Logging {
+
+  import TimeStampedWeakValueHashMap._
+
+  private val internalMap = new TimeStampedHashMap[A, WeakReference[B]](updateTimeStampOnGet)
+  private val insertCount = new AtomicInteger(0)
+
+  /** Return a map consisting only of entries whose values are still strongly reachable. */
+  private def nonNullReferenceMap = internalMap.filter { case (_, ref) => ref.get != null }
+
+  def get(key: A): Option[B] = internalMap.get(key)
+
+  def iterator: Iterator[(A, B)] = nonNullReferenceMap.iterator
+
+  override def + [B1 >: B](kv: (A, B1)): mutable.Map[A, B1] = {
+    val newMap = new TimeStampedWeakValueHashMap[A, B1]
+    val oldMap = nonNullReferenceMap.asInstanceOf[mutable.Map[A, WeakReference[B1]]]
+    newMap.internalMap.putAll(oldMap.toMap)
+    newMap.internalMap += kv
+    newMap
+  }
+
+  override def - (key: A): mutable.Map[A, B] = {
+    val newMap = new TimeStampedWeakValueHashMap[A, B]
+    newMap.internalMap.putAll(nonNullReferenceMap.toMap)
+    newMap.internalMap -= key
+    newMap
+  }
+
+  override def += (kv: (A, B)): this.type = {
+    internalMap += kv
+    if (insertCount.incrementAndGet() % CLEAR_NULL_VALUES_INTERVAL == 0) {
+      clearNullValues()
+    }
+    this
+  }
+
+  override def -= (key: A): this.type = {
+    internalMap -= key
+    this
+  }
+
+  override def update(key: A, value: B) = this += ((key, value))
+
+  override def apply(key: A): B = internalMap.apply(key)
+
+  override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = nonNullReferenceMap.filter(p)
+
+  override def empty: mutable.Map[A, B] = new TimeStampedWeakValueHashMap[A, B]()
+
+  override def size: Int = internalMap.size
+
+  override def foreach[U](f: ((A, B)) => U) = nonNullReferenceMap.foreach(f)
+
+  def putIfAbsent(key: A, value: B): Option[B] = internalMap.putIfAbsent(key, value)
+
+  def toMap: Map[A, B] = iterator.toMap
+
+  /** Remove old key-value pairs with timestamps earlier than `threshTime`. */
+  def clearOldValues(threshTime: Long) = internalMap.clearOldValues(threshTime)
+
+  /** Remove entries with values that are no longer strongly reachable. */
+  def clearNullValues() {
+    val it = internalMap.getEntrySet.iterator
+    while (it.hasNext) {
+      val entry = it.next()
+      if (entry.getValue.value.get == null) {
+        logDebug("Removing key " + entry.getKey + " because it is no longer strongly reachable.")
+        it.remove()
+      }
+    }
+  }
+
+  // For testing
+
+  def getTimestamp(key: A): Option[Long] = {
+    internalMap.getTimeStampedValue(key).map(_.timestamp)
+  }
+
+  def getReference(key: A): Option[WeakReference[B]] = {
+    internalMap.getTimeStampedValue(key).map(_.value)
+  }
+}
+
+/**
+ * Helper methods for converting to and from WeakReferences.
+ */
+private object TimeStampedWeakValueHashMap {
+
+  // Number of inserts after which entries with null references are removed
+  val CLEAR_NULL_VALUES_INTERVAL = 100
+
+  /* Implicit conversion methods to WeakReferences. */
+
+  implicit def toWeakReference[V](v: V): WeakReference[V] = new WeakReference[V](v)
+
+  implicit def toWeakReferenceTuple[K, V](kv: (K, V)): (K, WeakReference[V]) = {
+    kv match { case (k, v) => (k, toWeakReference(v)) }
+  }
+
+  implicit def toWeakReferenceFunction[K, V, R](p: ((K, V)) => R): ((K, WeakReference[V])) => R = {
+    (kv: (K, WeakReference[V])) => p(kv)
+  }
+
+  /* Implicit conversion methods from WeakReferences. */
+
+  implicit def fromWeakReference[V](ref: WeakReference[V]): V = ref.get
+
+  implicit def fromWeakReferenceOption[V](v: Option[WeakReference[V]]): Option[V] = {
+    v match {
+      case Some(ref) => Option(fromWeakReference(ref))
+      case None => None
+    }
+  }
+
+  implicit def fromWeakReferenceTuple[K, V](kv: (K, WeakReference[V])): (K, V) = {
+    kv match { case (k, v) => (k, fromWeakReference(v)) }
+  }
+
+  implicit def fromWeakReferenceIterator[K, V](
+      it: Iterator[(K, WeakReference[V])]): Iterator[(K, V)] = {
+    it.map(fromWeakReferenceTuple)
+  }
+
+  implicit def fromWeakReferenceMap[K, V](
+      map: mutable.Map[K, WeakReference[V]]) : mutable.Map[K, V] = {
+    mutable.Map(map.mapValues(fromWeakReference).toSeq: _*)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/11eabbe1/core/src/main/scala/org/apache/spark/util/Utils.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 4435b21..59da51f 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -499,10 +499,10 @@ private[spark] object Utils extends Logging {
   private val hostPortParseResults = new ConcurrentHashMap[String, (String, Int)]()
 
   def parseHostPort(hostPort: String): (String,  Int) = {
-    {
-      // Check cache first.
-      val cached = hostPortParseResults.get(hostPort)
-      if (cached != null) return cached
+    // Check cache first.
+    val cached = hostPortParseResults.get(hostPort)
+    if (cached != null) {
+      return cached
     }
 
     val indx: Int = hostPort.lastIndexOf(':')

http://git-wip-us.apache.org/repos/asf/spark/blob/11eabbe1/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala
index d2e303d..c5f24c6 100644
--- a/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala
@@ -56,7 +56,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
 
     val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, 
       conf = conf, securityManager = securityManagerBad)
-    val slaveTracker = new MapOutputTracker(conf)
+    val slaveTracker = new MapOutputTrackerWorker(conf)
     val selection = slaveSystem.actorSelection(
       s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
     val timeout = AkkaUtils.lookupTimeout(conf)
@@ -93,7 +93,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
 
     val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, 
       conf = badconf, securityManager = securityManagerBad)
-    val slaveTracker = new MapOutputTracker(conf)
+    val slaveTracker = new MapOutputTrackerWorker(conf)
     val selection = slaveSystem.actorSelection(
       s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
     val timeout = AkkaUtils.lookupTimeout(conf)
@@ -147,7 +147,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
 
     val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
       conf = goodconf, securityManager = securityManagerGood)
-    val slaveTracker = new MapOutputTracker(conf)
+    val slaveTracker = new MapOutputTrackerWorker(conf)
     val selection = slaveSystem.actorSelection(
       s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
     val timeout = AkkaUtils.lookupTimeout(conf)
@@ -200,7 +200,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
 
     val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
       conf = badconf, securityManager = securityManagerBad)
-    val slaveTracker = new MapOutputTracker(conf)
+    val slaveTracker = new MapOutputTrackerWorker(conf)
     val selection = slaveSystem.actorSelection(
       s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
     val timeout = AkkaUtils.lookupTimeout(conf)

http://git-wip-us.apache.org/repos/asf/spark/blob/11eabbe1/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
index 96ba392..c993625 100644
--- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
@@ -19,68 +19,297 @@ package org.apache.spark
 
 import org.scalatest.FunSuite
 
-class BroadcastSuite extends FunSuite with LocalSparkContext {
+import org.apache.spark.storage._
+import org.apache.spark.broadcast.{Broadcast, HttpBroadcast}
+import org.apache.spark.storage.BroadcastBlockId
 
+class BroadcastSuite extends FunSuite with LocalSparkContext {
 
-  override def afterEach() {
-    super.afterEach()
-    System.clearProperty("spark.broadcast.factory")
-  }
+  private val httpConf = broadcastConf("HttpBroadcastFactory")
+  private val torrentConf = broadcastConf("TorrentBroadcastFactory")
 
   test("Using HttpBroadcast locally") {
-    System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
-    sc = new SparkContext("local", "test")
-    val list = List(1, 2, 3, 4)
-    val listBroadcast = sc.broadcast(list)
-    val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum))
-    assert(results.collect.toSet === Set((1, 10), (2, 10)))
+    sc = new SparkContext("local", "test", httpConf)
+    val list = List[Int](1, 2, 3, 4)
+    val broadcast = sc.broadcast(list)
+    val results = sc.parallelize(1 to 2).map(x => (x, broadcast.value.sum))
+    assert(results.collect().toSet === Set((1, 10), (2, 10)))
   }
 
   test("Accessing HttpBroadcast variables from multiple threads") {
-    System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
-    sc = new SparkContext("local[10]", "test")
-    val list = List(1, 2, 3, 4)
-    val listBroadcast = sc.broadcast(list)
-    val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum))
-    assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet)
+    sc = new SparkContext("local[10]", "test", httpConf)
+    val list = List[Int](1, 2, 3, 4)
+    val broadcast = sc.broadcast(list)
+    val results = sc.parallelize(1 to 10).map(x => (x, broadcast.value.sum))
+    assert(results.collect().toSet === (1 to 10).map(x => (x, 10)).toSet)
   }
 
   test("Accessing HttpBroadcast variables in a local cluster") {
-    System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
     val numSlaves = 4
-    sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test")
-    val list = List(1, 2, 3, 4)
-    val listBroadcast = sc.broadcast(list)
-    val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum))
-    assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
+    sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", httpConf)
+    val list = List[Int](1, 2, 3, 4)
+    val broadcast = sc.broadcast(list)
+    val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum))
+    assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
   }
 
   test("Using TorrentBroadcast locally") {
-    System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
-    sc = new SparkContext("local", "test")
-    val list = List(1, 2, 3, 4)
-    val listBroadcast = sc.broadcast(list)
-    val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum))
-    assert(results.collect.toSet === Set((1, 10), (2, 10)))
+    sc = new SparkContext("local", "test", torrentConf)
+    val list = List[Int](1, 2, 3, 4)
+    val broadcast = sc.broadcast(list)
+    val results = sc.parallelize(1 to 2).map(x => (x, broadcast.value.sum))
+    assert(results.collect().toSet === Set((1, 10), (2, 10)))
   }
 
   test("Accessing TorrentBroadcast variables from multiple threads") {
-    System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
-    sc = new SparkContext("local[10]", "test")
-    val list = List(1, 2, 3, 4)
-    val listBroadcast = sc.broadcast(list)
-    val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum))
-    assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet)
+    sc = new SparkContext("local[10]", "test", torrentConf)
+    val list = List[Int](1, 2, 3, 4)
+    val broadcast = sc.broadcast(list)
+    val results = sc.parallelize(1 to 10).map(x => (x, broadcast.value.sum))
+    assert(results.collect().toSet === (1 to 10).map(x => (x, 10)).toSet)
   }
 
   test("Accessing TorrentBroadcast variables in a local cluster") {
-    System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
     val numSlaves = 4
-    sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test")
-    val list = List(1, 2, 3, 4)
-    val listBroadcast = sc.broadcast(list)
-    val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum))
-    assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
+    sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", torrentConf)
+    val list = List[Int](1, 2, 3, 4)
+    val broadcast = sc.broadcast(list)
+    val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum))
+    assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
+  }
+
+  test("Unpersisting HttpBroadcast on executors only in local mode") {
+    testUnpersistHttpBroadcast(distributed = false, removeFromDriver = false)
+  }
+
+  test("Unpersisting HttpBroadcast on executors and driver in local mode") {
+    testUnpersistHttpBroadcast(distributed = false, removeFromDriver = true)
+  }
+
+  test("Unpersisting HttpBroadcast on executors only in distributed mode") {
+    testUnpersistHttpBroadcast(distributed = true, removeFromDriver = false)
+  }
+
+  test("Unpersisting HttpBroadcast on executors and driver in distributed mode") {
+    testUnpersistHttpBroadcast(distributed = true, removeFromDriver = true)
+  }
+
+  test("Unpersisting TorrentBroadcast on executors only in local mode") {
+    testUnpersistTorrentBroadcast(distributed = false, removeFromDriver = false)
+  }
+
+  test("Unpersisting TorrentBroadcast on executors and driver in local mode") {
+    testUnpersistTorrentBroadcast(distributed = false, removeFromDriver = true)
+  }
+
+  test("Unpersisting TorrentBroadcast on executors only in distributed mode") {
+    testUnpersistTorrentBroadcast(distributed = true, removeFromDriver = false)
+  }
+
+  test("Unpersisting TorrentBroadcast on executors and driver in distributed mode") {
+    testUnpersistTorrentBroadcast(distributed = true, removeFromDriver = true)
+  }
+  /**
+   * Verify the persistence of state associated with an HttpBroadcast in either local mode or
+   * local-cluster mode (when distributed = true).
+   *
+   * This test creates a broadcast variable, uses it on all executors, and then unpersists it.
+   * In between each step, this test verifies that the broadcast blocks and the broadcast file
+   * are present only on the expected nodes.
+   */
+  private def testUnpersistHttpBroadcast(distributed: Boolean, removeFromDriver: Boolean) {
+    val numSlaves = if (distributed) 2 else 0
+
+    def getBlockIds(id: Long) = Seq[BroadcastBlockId](BroadcastBlockId(id))
+
+    // Verify that the broadcast file is created, and blocks are persisted only on the driver
+    def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
+      assert(blockIds.size === 1)
+      val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
+      assert(statuses.size === 1)
+      statuses.head match { case (bm, status) =>
+        assert(bm.executorId === "<driver>", "Block should only be on the driver")
+        assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
+        assert(status.memSize > 0, "Block should be in memory store on the driver")
+        assert(status.diskSize === 0, "Block should not be in disk store on the driver")
+      }
+      if (distributed) {
+        // this file is only generated in distributed mode
+        assert(HttpBroadcast.getFile(blockIds.head.broadcastId).exists, "Broadcast file not found!")
+      }
+    }
+
+    // Verify that blocks are persisted in both the executors and the driver
+    def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
+      assert(blockIds.size === 1)
+      val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
+      assert(statuses.size === numSlaves + 1)
+      statuses.foreach { case (_, status) =>
+        assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
+        assert(status.memSize > 0, "Block should be in memory store")
+        assert(status.diskSize === 0, "Block should not be in disk store")
+      }
+    }
+
+    // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver
+    // is true. In the latter case, also verify that the broadcast file is deleted on the driver.
+    def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
+      assert(blockIds.size === 1)
+      val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
+      val expectedNumBlocks = if (removeFromDriver) 0 else 1
+      val possiblyNot = if (removeFromDriver) "" else " not"
+      assert(statuses.size === expectedNumBlocks,
+        "Block should%s be unpersisted on the driver".format(possiblyNot))
+      if (distributed && removeFromDriver) {
+        // this file is only generated in distributed mode
+        assert(!HttpBroadcast.getFile(blockIds.head.broadcastId).exists,
+          "Broadcast file should%s be deleted".format(possiblyNot))
+      }
+    }
+
+    testUnpersistBroadcast(distributed, numSlaves, httpConf, getBlockIds, afterCreation,
+      afterUsingBroadcast, afterUnpersist, removeFromDriver)
+  }
+
+  /**
+   * Verify persistence of state associated with a TorrentBroadcast in local or local-cluster mode.
+   *
+   * This test creates a broadcast variable, uses it on all executors, and then unpersists it.
+   * In between each step, this test verifies that the broadcast blocks are present only on the
+   * expected nodes.
+   */
+  private def testUnpersistTorrentBroadcast(distributed: Boolean, removeFromDriver: Boolean) {
+    val numSlaves = if (distributed) 2 else 0
+
+    def getBlockIds(id: Long) = {
+      val broadcastBlockId = BroadcastBlockId(id)
+      val metaBlockId = BroadcastBlockId(id, "meta")
+      // Assume broadcast value is small enough to fit into 1 piece
+      val pieceBlockId = BroadcastBlockId(id, "piece0")
+      if (distributed) {
+        // the metadata and piece blocks are generated only in distributed mode
+        Seq[BroadcastBlockId](broadcastBlockId, metaBlockId, pieceBlockId)
+      } else {
+        Seq[BroadcastBlockId](broadcastBlockId)
+      }
+    }
+
+    // Verify that each broadcast block (not just the first) is persisted only on the driver
+    def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
+      blockIds.foreach { blockId =>
+        val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+        assert(statuses.size === 1)
+        statuses.head match { case (bm, status) =>
+          assert(bm.executorId === "<driver>", "Block should only be on the driver")
+          assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
+          assert(status.memSize > 0, "Block should be in memory store on the driver")
+          assert(status.diskSize === 0, "Block should not be in disk store on the driver")
+        }
+      }
+    }
+
+    // Verify that blocks are persisted in both the executors and the driver
+    def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
+      blockIds.foreach { blockId =>
+        val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+        if (blockId.field == "meta") {
+          // Meta data is only on the driver
+          assert(statuses.size === 1)
+          statuses.head match { case (bm, _) => assert(bm.executorId === "<driver>") }
+        } else {
+          // Other blocks are on both the executors and the driver
+          assert(statuses.size === numSlaves + 1,
+            blockId + " has " + statuses.size + " statuses: " + statuses.mkString(","))
+          statuses.foreach { case (_, status) =>
+            assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
+            assert(status.memSize > 0, "Block should be in memory store")
+            assert(status.diskSize === 0, "Block should not be in disk store")
+          }
+        }
+      }
+    }
+
+    // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver
+    // is true.
+    def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
+      val expectedNumBlocks = if (removeFromDriver) 0 else 1
+      val possiblyNot = if (removeFromDriver) "" else " not"
+      blockIds.foreach { blockId =>
+        val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+        assert(statuses.size === expectedNumBlocks,
+          "Block should%s be unpersisted on the driver".format(possiblyNot))
+      }
+    }
+
+    testUnpersistBroadcast(distributed, numSlaves,  torrentConf, getBlockIds, afterCreation,
+      afterUsingBroadcast, afterUnpersist, removeFromDriver)
+  }
+
+  /**
+   * This test runs in 4 steps:
+   *
+   * 1) Create broadcast variable, and verify that all state is persisted on the driver.
+   * 2) Use the broadcast variable on all executors, and verify that all state is persisted
+   *    on both the driver and the executors.
+   * 3) Unpersist the broadcast, and verify that all state is removed where it should be.
+   * 4) [Optional] If removeFromDriver is false, we verify that the broadcast is re-usable.
+   */
+  private def testUnpersistBroadcast(
+      distributed: Boolean,
+      numSlaves: Int,  // used only when distributed = true
+      broadcastConf: SparkConf,
+      getBlockIds: Long => Seq[BroadcastBlockId],
+      afterCreation: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
+      afterUsingBroadcast: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
+      afterUnpersist: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
+      removeFromDriver: Boolean) {
+
+    sc = if (distributed) {
+      new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", broadcastConf)
+    } else {
+      new SparkContext("local", "test", broadcastConf)
+    }
+    val blockManagerMaster = sc.env.blockManager.master
+    val list = List[Int](1, 2, 3, 4)
+
+    // Create broadcast variable
+    val broadcast = sc.broadcast(list)
+    val blocks = getBlockIds(broadcast.id)
+    afterCreation(blocks, blockManagerMaster)
+
+    // Use broadcast variable on all executors
+    val partitions = 10
+    assert(partitions > numSlaves)
+    val results = sc.parallelize(1 to partitions, partitions).map(x => (x, broadcast.value.sum))
+    assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet)
+    afterUsingBroadcast(blocks, blockManagerMaster)
+
+    // Unpersist broadcast
+    if (removeFromDriver) {
+      broadcast.destroy(blocking = true)
+    } else {
+      broadcast.unpersist(blocking = true)
+    }
+    afterUnpersist(blocks, blockManagerMaster)
+
+    // If the broadcast is removed from driver, all subsequent uses of the broadcast variable
+    // should throw SparkExceptions. Otherwise, the result should be the same as before.
+    if (removeFromDriver) {
+      // Using this variable on the executors crashes them, which hangs the test.
+      // Instead, crash the driver by directly accessing the broadcast value.
+      intercept[SparkException] { broadcast.value }
+      intercept[SparkException] { broadcast.unpersist() }
+      intercept[SparkException] { broadcast.destroy(blocking = true) }
+    } else {
+      val results = sc.parallelize(1 to partitions, partitions).map(x => (x, broadcast.value.sum))
+      assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet)
+    }
   }
 
+  /** Helper method to create a SparkConf that uses the given broadcast factory. */
+  private def broadcastConf(factoryName: String): SparkConf = {
+    val conf = new SparkConf
+    conf.set("spark.broadcast.factory", "org.apache.spark.broadcast.%s".format(factoryName))
+    conf
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/11eabbe1/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
new file mode 100644
index 0000000..e50981c
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
@@ -0,0 +1,415 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.lang.ref.WeakReference
+
+import scala.collection.mutable.{HashSet, SynchronizedSet}
+import scala.util.Random
+
+import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.concurrent.Eventually
+import org.scalatest.concurrent.Eventually._
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.{BlockId, BroadcastBlockId, RDDBlockId, ShuffleBlockId}
+
+class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
+
+  implicit val defaultTimeout = timeout(10000 millis)
+  val conf = new SparkConf()
+    .setMaster("local[2]")
+    .setAppName("ContextCleanerSuite")
+    .set("spark.cleaner.referenceTracking.blocking", "true")
+
+  before {
+    sc = new SparkContext(conf)
+  }
+
+  after {
+    if (sc != null) {
+      sc.stop()
+      sc = null
+    }
+  }
+
+
+  test("cleanup RDD") {
+    val rdd = newRDD.persist()
+    val collected = rdd.collect().toList
+    val tester = new CleanerTester(sc, rddIds = Seq(rdd.id))
+
+    // Explicit cleanup
+    cleaner.doCleanupRDD(rdd.id, blocking = true)
+    tester.assertCleanup()
+
+    // Verify that RDDs can be re-executed after cleaning up
+    assert(rdd.collect().toList === collected)
+  }
+
+  test("cleanup shuffle") {
+    val (rdd, shuffleDeps) = newRDDWithShuffleDependencies
+    val collected = rdd.collect().toList
+    val tester = new CleanerTester(sc, shuffleIds = shuffleDeps.map(_.shuffleId))
+
+    // Explicit cleanup
+    shuffleDeps.foreach(s => cleaner.doCleanupShuffle(s.shuffleId, blocking = true))
+    tester.assertCleanup()
+
+    // Verify that shuffles can be re-executed after cleaning up
+    assert(rdd.collect().toList === collected)
+  }
+
+  test("cleanup broadcast") {
+    val broadcast = newBroadcast
+    val tester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id))
+
+    // Explicit cleanup
+    cleaner.doCleanupBroadcast(broadcast.id, blocking = true)
+    tester.assertCleanup()
+  }
+
+  test("automatically cleanup RDD") {
+    var rdd = newRDD.persist()
+    rdd.count()
+
+    // Test that GC does not cause RDD cleanup due to a strong reference
+    val preGCTester =  new CleanerTester(sc, rddIds = Seq(rdd.id))
+    runGC()
+    intercept[Exception] {
+      preGCTester.assertCleanup()(timeout(1000 millis))
+    }
+
+    // Test that GC causes RDD cleanup after dereferencing the RDD
+    val postGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id))
+    rdd = null // Make RDD out of scope
+    runGC()
+    postGCTester.assertCleanup()
+  }
+
+  test("automatically cleanup shuffle") {
+    var rdd = newShuffleRDD
+    rdd.count()
+
+    // Test that GC does not cause shuffle cleanup due to a strong reference
+    val preGCTester = new CleanerTester(sc, shuffleIds = Seq(0))
+    runGC()
+    intercept[Exception] {
+      preGCTester.assertCleanup()(timeout(1000 millis))
+    }
+
+    // Test that GC causes shuffle cleanup after dereferencing the RDD
+    val postGCTester = new CleanerTester(sc, shuffleIds = Seq(0))
+    rdd = null  // Make RDD out of scope, so that corresponding shuffle goes out of scope
+    runGC()
+    postGCTester.assertCleanup()
+  }
+
+  test("automatically cleanup broadcast") {
+    var broadcast = newBroadcast
+
+    // Test that GC does not cause broadcast cleanup due to a strong reference
+    val preGCTester =  new CleanerTester(sc, broadcastIds = Seq(broadcast.id))
+    runGC()
+    intercept[Exception] {
+      preGCTester.assertCleanup()(timeout(1000 millis))
+    }
+
+    // Test that GC causes broadcast cleanup after dereferencing the broadcast variable
+    val postGCTester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id))
+    broadcast = null  // Make broadcast variable out of scope
+    runGC()
+    postGCTester.assertCleanup()
+  }
+
+  test("automatically cleanup RDD + shuffle + broadcast") {
+    val numRdds = 100
+    val numBroadcasts = 4 // Broadcasts are more costly
+    val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer
+    val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast).toBuffer
+    val rddIds = sc.persistentRdds.keys.toSeq
+    val shuffleIds = 0 until sc.newShuffleId
+    val broadcastIds = 0L until numBroadcasts
+
+    val preGCTester =  new CleanerTester(sc, rddIds, shuffleIds, broadcastIds)
+    runGC()
+    intercept[Exception] {
+      preGCTester.assertCleanup()(timeout(1000 millis))
+    }
+
+    // Test that GC triggers the cleanup of all variables after dereferencing them
+    val postGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds)
+    broadcastBuffer.clear()
+    rddBuffer.clear()
+    runGC()
+    postGCTester.assertCleanup()
+  }
+
+  test("automatically cleanup RDD + shuffle + broadcast in distributed mode") {
+    sc.stop()
+
+    val conf2 = new SparkConf()
+      .setMaster("local-cluster[2, 1, 512]")
+      .setAppName("ContextCleanerSuite")
+      .set("spark.cleaner.referenceTracking.blocking", "true")
+    sc = new SparkContext(conf2)
+
+    val numRdds = 10
+    val numBroadcasts = 4 // Broadcasts are more costly
+    val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer
+    val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast).toBuffer
+    val rddIds = sc.persistentRdds.keys.toSeq
+    val shuffleIds = 0 until sc.newShuffleId
+    val broadcastIds = 0L until numBroadcasts
+
+    val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds)
+    runGC()
+    intercept[Exception] {
+      preGCTester.assertCleanup()(timeout(1000 millis))
+    }
+
+    // Test that GC triggers the cleanup of all variables after dereferencing them
+    val postGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds)
+    broadcastBuffer.clear()
+    rddBuffer.clear()
+    runGC()
+    postGCTester.assertCleanup()
+  }
+
+  //------ Helper functions ------
+
+  def newRDD = sc.makeRDD(1 to 10)
+  def newPairRDD = newRDD.map(_ -> 1)
+  def newShuffleRDD = newPairRDD.reduceByKey(_ + _)
+  def newBroadcast = sc.broadcast(1 to 100)
+  def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _]]) = {
+    def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = {
+      rdd.dependencies ++ rdd.dependencies.flatMap { dep =>
+        getAllDependencies(dep.rdd)
+      }
+    }
+    val rdd = newShuffleRDD
+
+    // Get all the shuffle dependencies
+    val shuffleDeps = getAllDependencies(rdd)
+      .filter(_.isInstanceOf[ShuffleDependency[_, _]])
+      .map(_.asInstanceOf[ShuffleDependency[_, _]])
+    (rdd, shuffleDeps)
+  }
+
+  def randomRdd = {
+    val rdd: RDD[_] = Random.nextInt(3) match {
+      case 0 => newRDD
+      case 1 => newShuffleRDD
+      case 2 => newPairRDD.join(newPairRDD)
+    }
+    if (Random.nextBoolean()) rdd.persist()
+    rdd.count()
+    rdd
+  }
+
+  def randomBroadcast = {
+    sc.broadcast(Random.nextInt(Int.MaxValue))
+  }
+
+  /** Run GC and make sure it actually has run */
+  def runGC() {
+    val weakRef = new WeakReference(new Object())
+    val startTime = System.currentTimeMillis
+    System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC.
+    // Wait until a weak reference object has been GCed
+    while(System.currentTimeMillis - startTime < 10000 && weakRef.get != null) {
+      System.gc()
+      Thread.sleep(200)
+    }
+  }
+
+  def cleaner = sc.cleaner.get
+}
+
+
+/** Class to test whether RDDs, shuffles, etc. have been successfully cleaned. */
+class CleanerTester(
+    sc: SparkContext,
+    rddIds: Seq[Int] = Seq.empty,
+    shuffleIds: Seq[Int] = Seq.empty,
+    broadcastIds: Seq[Long] = Seq.empty)
+  extends Logging {
+
+  val toBeCleanedRDDIds = new HashSet[Int] with SynchronizedSet[Int] ++= rddIds
+  val toBeCleanedShuffleIds = new HashSet[Int] with SynchronizedSet[Int] ++= shuffleIds
+  val toBeCleanedBroadcstIds = new HashSet[Long] with SynchronizedSet[Long] ++= broadcastIds
+  val isDistributed = !sc.isLocal
+
+  val cleanerListener = new CleanerListener {
+    def rddCleaned(rddId: Int): Unit = {
+      toBeCleanedRDDIds -= rddId
+      logInfo("RDD "+ rddId + " cleaned")
+    }
+
+    def shuffleCleaned(shuffleId: Int): Unit = {
+      toBeCleanedShuffleIds -= shuffleId
+      logInfo("Shuffle " + shuffleId + " cleaned")
+    }
+
+    def broadcastCleaned(broadcastId: Long): Unit = {  // mark broadcast as cleaned and log it
+      toBeCleanedBroadcstIds -= broadcastId
+      logInfo("Broadcast " + broadcastId + " cleaned")
+    }
+  }
+
+  val MAX_VALIDATION_ATTEMPTS = 10
+  val VALIDATION_ATTEMPT_INTERVAL = 100
+
+  logInfo("Attempting to validate before cleanup:\n" + uncleanedResourcesToString)
+  preCleanupValidate()
+  sc.cleaner.get.attachListener(cleanerListener)
+
+  /** Assert that all the stuff has been cleaned up */
+  def assertCleanup()(implicit waitTimeout: Eventually.Timeout) {
+    try {
+      eventually(waitTimeout, interval(100 millis)) {
+        assert(isAllCleanedUp)
+      }
+      postCleanupValidate()
+    } finally {
+      logInfo("Resources left from cleaning up:\n" + uncleanedResourcesToString)
+    }
+  }
+
+  /** Verify that the tracked RDDs, shuffles and broadcasts currently occupy resources. */
+  private def preCleanupValidate() {
+    assert(rddIds.nonEmpty || shuffleIds.nonEmpty || broadcastIds.nonEmpty, "Nothing to cleanup")
+
+    // Verify the RDDs have been persisted and blocks are present
+    rddIds.foreach { rddId =>
+      assert(
+        sc.persistentRdds.contains(rddId),
+        "RDD " + rddId + " have not been persisted, cannot start cleaner test"
+      )
+
+      assert(
+        !getRDDBlocks(rddId).isEmpty,
+        "Blocks of RDD " + rddId + " cannot be found in block manager, " +
+          "cannot start cleaner test"
+      )
+    }
+
+    // Verify the shuffle ids are registered and blocks are present
+    shuffleIds.foreach { shuffleId =>
+      assert(
+        mapOutputTrackerMaster.containsShuffle(shuffleId),
+        "Shuffle " + shuffleId + " have not been registered, cannot start cleaner test"
+      )
+
+      assert(
+        !getShuffleBlocks(shuffleId).isEmpty,
+        "Blocks of shuffle " + shuffleId + " cannot be found in block manager, " +
+          "cannot start cleaner test"
+      )
+    }
+
+    // Verify that the broadcast blocks are present
+    broadcastIds.foreach { broadcastId =>
+      assert(
+        !getBroadcastBlocks(broadcastId).isEmpty,
+        "Blocks of broadcast " + broadcastId + " cannot be found in block manager, " +
+          "cannot start cleaner test"
+      )
+    }
+  }
+
+  /**
+   * Verify that RDDs, shuffles, etc. do not occupy resources. Tests multiple times as there
+   * is no guarantee on how long it will take to clean up the resources.
+   */
+  private def postCleanupValidate() {
+    // Verify the RDDs are no longer tracked as persistent and their blocks are gone
+    rddIds.foreach { rddId =>
+      assert(
+        !sc.persistentRdds.contains(rddId),
+        "RDD " + rddId + " was not cleared from sc.persistentRdds"
+      )
+
+      assert(
+        getRDDBlocks(rddId).isEmpty,
+        "Blocks of RDD " + rddId + " were not cleared from block manager"
+      )
+    }
+
+    // Verify the shuffle ids have been deregistered and their blocks are gone
+    shuffleIds.foreach { shuffleId =>
+      assert(
+        !mapOutputTrackerMaster.containsShuffle(shuffleId),
+        "Shuffle " + shuffleId + " was not deregistered from map output tracker"
+      )
+
+      assert(
+        getShuffleBlocks(shuffleId).isEmpty,
+        "Blocks of shuffle " + shuffleId + " were not cleared from block manager"
+      )
+    }
+
+    // Verify that the broadcast blocks are gone
+    broadcastIds.foreach { broadcastId =>
+      assert(
+        getBroadcastBlocks(broadcastId).isEmpty,
+        "Blocks of broadcast " + broadcastId + " were not cleared from block manager"
+      )
+    }
+  }
+
+  private def uncleanedResourcesToString = {
+    s"""
+      |\tRDDs = ${toBeCleanedRDDIds.toSeq.sorted.mkString("[", ", ", "]")}
+      |\tShuffles = ${toBeCleanedShuffleIds.toSeq.sorted.mkString("[", ", ", "]")}
+      |\tBroadcasts = ${toBeCleanedBroadcstIds.toSeq.sorted.mkString("[", ", ", "]")}
+    """.stripMargin
+  }
+
+  private def isAllCleanedUp =
+    toBeCleanedRDDIds.isEmpty &&
+    toBeCleanedShuffleIds.isEmpty &&
+    toBeCleanedBroadcstIds.isEmpty
+
+  private def getRDDBlocks(rddId: Int): Seq[BlockId] = {
+    blockManager.master.getMatchingBlockIds( _ match {
+      case RDDBlockId(`rddId`, _) => true
+      case _ => false
+    }, askSlaves = true)
+  }
+
+  private def getShuffleBlocks(shuffleId: Int): Seq[BlockId] = {
+    blockManager.master.getMatchingBlockIds( _ match {
+      case ShuffleBlockId(`shuffleId`, _, _) => true
+      case _ => false
+    }, askSlaves = true)
+  }
+
+  private def getBroadcastBlocks(broadcastId: Long): Seq[BlockId] = {
+    blockManager.master.getMatchingBlockIds( _ match {
+      case BroadcastBlockId(`broadcastId`, _) => true
+      case _ => false
+    }, askSlaves = true)
+  }
+
+  private def blockManager = sc.env.blockManager
+  private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/11eabbe1/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index a5bd72e..6b2571c 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -57,12 +57,13 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
     tracker.stop()
   }
 
-  test("master register and fetch") {
+  test("master register shuffle and fetch") {
     val actorSystem = ActorSystem("test")
     val tracker = new MapOutputTrackerMaster(conf)
     tracker.trackerActor =
       actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf)))
     tracker.registerShuffle(10, 2)
+    assert(tracker.containsShuffle(10))
     val compressedSize1000 = MapOutputTracker.compressSize(1000L)
     val compressedSize10000 = MapOutputTracker.compressSize(10000L)
     val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
@@ -77,7 +78,25 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
     tracker.stop()
   }
 
-  test("master register and unregister and fetch") {
+  test("master register and unregister shuffle") {
+    val actorSystem = ActorSystem("test")
+    val tracker = new MapOutputTrackerMaster(conf)
+    tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf)))
+    tracker.registerShuffle(10, 2)
+    val compressedSize1000 = MapOutputTracker.compressSize(1000L)
+    val compressedSize10000 = MapOutputTracker.compressSize(10000L)
+    tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0),
+      Array(compressedSize1000, compressedSize10000)))
+    tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0),
+      Array(compressedSize10000, compressedSize1000)))
+    assert(tracker.containsShuffle(10))
+    assert(tracker.getServerStatuses(10, 0).nonEmpty)
+    tracker.unregisterShuffle(10)
+    assert(!tracker.containsShuffle(10))
+    assert(tracker.getServerStatuses(10, 0).isEmpty)
+  }
+
+  test("master register shuffle and unregister map output and fetch") {
     val actorSystem = ActorSystem("test")
     val tracker = new MapOutputTrackerMaster(conf)
     tracker.trackerActor =
@@ -114,7 +133,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
 
     val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf,
       securityManager = new SecurityManager(conf))
-    val slaveTracker = new MapOutputTracker(conf)
+    val slaveTracker = new MapOutputTrackerWorker(conf)
     val selection = slaveSystem.actorSelection(
       s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
     val timeout = AkkaUtils.lookupTimeout(conf)


Mime
View raw message