spark-commits mailing list archives

From ma...@apache.org
Subject [1/4] git commit: Split MapOutputTracker into Master/Worker classes.
Date Tue, 22 Oct 2013 17:20:46 GMT
Updated Branches:
  refs/heads/master b84193c5b -> a0e08f0fb


Split MapOutputTracker into Master/Worker classes.

Previously, MapOutputTracker contained fields and methods that
were only applicable to the master or worker instances.  This
commit introduces a MapOutputTrackerMaster class to prevent
the master-specific methods from being accessed on workers.

I also renamed a few methods and made others protected/private.
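
For illustration, here is a minimal, self-contained Scala sketch of the resulting split
(invented stand-in names, not the actual Spark classes): worker-side state and epoch
handling stay in the base class, while registration and epoch bumping move to a
master-only subclass, so master-specific calls on a worker-typed reference no longer
compile. SparkEnv (see the diff below) then instantiates the master variant on the
driver and the plain tracker elsewhere.

    import scala.collection.mutable

    class WorkerSideTracker {                            // plays the role of MapOutputTracker
      protected val mapStatuses = mutable.Map.empty[Int, Array[String]]
      protected var epoch: Long = 0

      def getEpoch: Long = epoch                         // workers may read the epoch...
      def updateEpoch(newEpoch: Long): Unit =            // ...and clear their cache on a newer one
        if (newEpoch > epoch) { epoch = newEpoch; mapStatuses.clear() }
    }

    class MasterSideTracker extends WorkerSideTracker {  // plays the role of MapOutputTrackerMaster
      def registerShuffle(shuffleId: Int, numMaps: Int): Unit =
        mapStatuses(shuffleId) = new Array[String](numMaps)
      def incrementEpoch(): Unit = epoch += 1            // only the master advances the epoch
    }

    object Demo extends App {
      val onDriver: WorkerSideTracker = new MasterSideTracker
      // onDriver.registerShuffle(10, 2)                 // won't compile: master-only method
    }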


Project: http://git-wip-us.apache.org/repos/asf/incubator-spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-spark/commit/9159d2d0
Tree: http://git-wip-us.apache.org/repos/asf/incubator-spark/tree/9159d2d0
Diff: http://git-wip-us.apache.org/repos/asf/incubator-spark/diff/9159d2d0

Branch: refs/heads/master
Commit: 9159d2d09d459e8879fee2222edd53860adc2b44
Parents: 6511bbe
Author: Josh Rosen <joshrosen@eecs.berkeley.edu>
Authored: Sat Oct 19 16:44:18 2013 -0700
Committer: Josh Rosen <rosenville@gmail.com>
Committed: Sat Oct 19 20:01:22 2013 -0700

----------------------------------------------------------------------
 .../org/apache/spark/MapOutputTracker.scala     | 172 ++++++++++---------
 .../main/scala/org/apache/spark/SparkEnv.scala  |   8 +-
 .../apache/spark/scheduler/DAGScheduler.scala   |   5 +-
 .../apache/spark/MapOutputTrackerSuite.scala    |  20 +--
 .../spark/scheduler/DAGSchedulerSuite.scala     |   6 +-
 5 files changed, 113 insertions(+), 98 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/9159d2d0/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 1e3f1eb..f0f8f2d 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -20,13 +20,11 @@ package org.apache.spark
 import java.io._
 import java.util.zip.{GZIPInputStream, GZIPOutputStream}
 
-import scala.collection.mutable.HashMap
 import scala.collection.mutable.HashSet
 
 import akka.actor._
 import akka.dispatch._
 import akka.pattern.ask
-import akka.remote._
 import akka.util.Duration
 
 
@@ -40,11 +38,12 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String
   extends MapOutputTrackerMessage
 private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
 
-private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Actor with Logging {
+private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster)
+  extends Actor with Logging {
   def receive = {
     case GetMapOutputStatuses(shuffleId: Int, requester: String) =>
       logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + requester)
-      sender ! tracker.getSerializedLocations(shuffleId)
+      sender ! tracker.getSerializedMapOutputStatuses(shuffleId)
 
     case StopMapOutputTracker =>
       logInfo("MapOutputTrackerActor stopped!")
@@ -60,22 +59,19 @@ private[spark] class MapOutputTracker extends Logging {
   // Set to the MapOutputTrackerActor living on the driver
   var trackerActor: ActorRef = _
 
-  private var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
+  protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
 
   // Incremented every time a fetch fails so that client nodes know to clear
   // their cache of map output locations if this happens.
-  private var epoch: Long = 0
-  private val epochLock = new java.lang.Object
+  protected var epoch: Long = 0
+  protected val epochLock = new java.lang.Object
 
-  // Cache a serialized version of the output statuses for each shuffle to send them out faster
-  var cacheEpoch = epoch
-  private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
-
-  val metadataCleaner = new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup)
+  private val metadataCleaner =
+    new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup)
 
   // Send a message to the trackerActor and get its result within a default timeout, or
   // throw a SparkException if this fails.
-  def askTracker(message: Any): Any = {
+  private def askTracker(message: Any): Any = {
     try {
       val future = trackerActor.ask(message)(timeout)
       return Await.result(future, timeout)
@@ -86,50 +82,12 @@ private[spark] class MapOutputTracker extends Logging {
   }
 
   // Send a one-way message to the trackerActor, to which we expect it to reply with true.
-  def communicate(message: Any) {
+  private def communicate(message: Any) {
     if (askTracker(message) != true) {
       throw new SparkException("Error reply received from MapOutputTracker")
     }
   }
 
-  def registerShuffle(shuffleId: Int, numMaps: Int) {
-    if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
-      throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
-    }
-  }
-
-  def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
-    var array = mapStatuses(shuffleId)
-    array.synchronized {
-      array(mapId) = status
-    }
-  }
-
-  def registerMapOutputs(
-      shuffleId: Int,
-      statuses: Array[MapStatus],
-      changeEpoch: Boolean = false) {
-    mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
-    if (changeEpoch) {
-      incrementEpoch()
-    }
-  }
-
-  def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
-    var arrayOpt = mapStatuses.get(shuffleId)
-    if (arrayOpt.isDefined && arrayOpt.get != null) {
-      var array = arrayOpt.get
-      array.synchronized {
-        if (array(mapId) != null && array(mapId).location == bmAddress) {
-          array(mapId) = null
-        }
-      }
-      incrementEpoch()
-    } else {
-      throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
-    }
-  }
-
   // Remembers which map output locations are currently being fetched on a worker
   private val fetching = new HashSet[Int]
 
@@ -168,7 +126,7 @@ private[spark] class MapOutputTracker extends Logging {
         try {
           val fetchedBytes =
             askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]]
-          fetchedStatuses = deserializeStatuses(fetchedBytes)
+          fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
           logInfo("Got the output locations")
           mapStatuses.put(shuffleId, fetchedStatuses)
         } finally {
@@ -194,9 +152,8 @@ private[spark] class MapOutputTracker extends Logging {
     }
   }
 
-  private def cleanup(cleanupTime: Long) {
+  protected def cleanup(cleanupTime: Long) {
     mapStatuses.clearOldValues(cleanupTime)
-    cachedSerializedStatuses.clearOldValues(cleanupTime)
   }
 
   def stop() {
@@ -206,15 +163,7 @@ private[spark] class MapOutputTracker extends Logging {
     trackerActor = null
   }
 
-  // Called on master to increment the epoch number
-  def incrementEpoch() {
-    epochLock.synchronized {
-      epoch += 1
-      logDebug("Increasing epoch to " + epoch)
-    }
-  }
-
-  // Called on master or workers to get current epoch number
+  // Called to get current epoch number
   def getEpoch: Long = {
     epochLock.synchronized {
       return epoch
@@ -228,14 +177,63 @@ private[spark] class MapOutputTracker extends Logging {
     epochLock.synchronized {
       if (newEpoch > epoch) {
         logInfo("Updating epoch to " + newEpoch + " and clearing cache")
-        // mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
-        mapStatuses.clear()
         epoch = newEpoch
+        mapStatuses.clear()
+      }
+    }
+  }
+}
+
+private[spark] class MapOutputTrackerMaster extends MapOutputTracker {
+
+  // Cache a serialized version of the output statuses for each shuffle to send them out faster
+  private var cacheEpoch = epoch
+  private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
+
+  def registerShuffle(shuffleId: Int, numMaps: Int) {
+    if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
+      throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
+    }
+  }
+
+  def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
+    val array = mapStatuses(shuffleId)
+    array.synchronized {
+      array(mapId) = status
+    }
+  }
+
+  def registerMapOutputs(shuffleId: Int, statuses: Array[MapStatus],
+                         changeEpoch: Boolean = false) {
+    mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
+    if (changeEpoch) {
+      incrementEpoch()
+    }
+  }
+
+  def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
+    val arrayOpt = mapStatuses.get(shuffleId)
+    if (arrayOpt.isDefined && arrayOpt.get != null) {
+      val array = arrayOpt.get
+      array.synchronized {
+        if (array(mapId) != null && array(mapId).location == bmAddress) {
+          array(mapId) = null
+        }
       }
+      incrementEpoch()
+    } else {
+      throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
     }
   }
 
-  def getSerializedLocations(shuffleId: Int): Array[Byte] = {
+  def incrementEpoch() {
+    epochLock.synchronized {
+      epoch += 1
+      logDebug("Increasing epoch to " + epoch)
+    }
+  }
+
+  def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = {
     var statuses: Array[MapStatus] = null
     var epochGotten: Long = -1
     epochLock.synchronized {
@@ -253,7 +251,7 @@ private[spark] class MapOutputTracker extends Logging {
     }
     // If we got here, we failed to find the serialized locations in the cache, so we pulled
     // out a snapshot of the locations as "locs"; let's serialize and return that
-    val bytes = serializeStatuses(statuses)
+    val bytes = MapOutputTracker.serializeMapStatuses(statuses)
     logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
     // Add them into the table only if the epoch hasn't changed while we were working
     epochLock.synchronized {
@@ -261,13 +259,34 @@ private[spark] class MapOutputTracker extends Logging {
         cachedSerializedStatuses(shuffleId) = bytes
       }
     }
-    return bytes
+    bytes
+  }
+
+  protected override def cleanup(cleanupTime: Long) {
+    super.cleanup(cleanupTime)
+    cachedSerializedStatuses.clearOldValues(cleanupTime)
+  }
+
+  override def stop() {
+    super.stop()
+    cachedSerializedStatuses.clear()
+  }
+
+  override def updateEpoch(newEpoch: Long) {
+    // This might be called on the MapOutputTrackerMaster if we're running in local mode:
+    epochLock.synchronized {
+      assert (newEpoch == epoch)
+    }
   }
+}
+
+private[spark] object MapOutputTracker {
+  private val LOG_BASE = 1.1
 
   // Serialize an array of map output locations into an efficient byte format so that we can send
   // it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
   // generally be pretty compressible because many map outputs will be on the same hostname.
-  private def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = {
+  def serializeMapStatuses(statuses: Array[MapStatus]): Array[Byte] = {
     val out = new ByteArrayOutputStream
     val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
     // Since statuses can be modified in parallel, sync on it
@@ -278,18 +297,11 @@ private[spark] class MapOutputTracker extends Logging {
     out.toByteArray
   }
 
-  // Opposite of serializeStatuses.
-  def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = {
+  // Opposite of serializeMapStatuses.
+  def deserializeMapStatuses(bytes: Array[Byte]): Array[MapStatus] = {
     val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
-    objIn.readObject().
-      // // drop all null's from status - not sure why they are occuring though. Causes NPE downstream in slave if present
-      // comment this out - nulls could be due to missing location ? 
-      asInstanceOf[Array[MapStatus]] // .filter( _ != null )
+    objIn.readObject().asInstanceOf[Array[MapStatus]]
   }
-}
-
-private[spark] object MapOutputTracker {
-  private val LOG_BASE = 1.1
 
   // Convert an array of MapStatuses to locations and sizes for a given reduce ID. If
   // any of the statuses is null (indicating a missing location due to a failed mapper),
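
For reference, here is a minimal standalone Scala sketch (not the Spark code itself) of
the GZIP-compressed Java serialization round trip that serializeMapStatuses and
deserializeMapStatuses perform above; Loc is an invented stand-in for MapStatus.

    import java.io._
    import java.util.zip.{GZIPInputStream, GZIPOutputStream}

    case class Loc(host: String, sizes: Array[Byte])

    object GzipRoundTrip extends App {
      def serialize(statuses: Array[Loc]): Array[Byte] = {
        val out = new ByteArrayOutputStream
        val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
        objOut.writeObject(statuses)   // many entries share a hostname, so GZIP compresses well
        objOut.close()                 // closing flushes the GZIP stream
        out.toByteArray
      }

      def deserialize(bytes: Array[Byte]): Array[Loc] =
        new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
          .readObject().asInstanceOf[Array[Loc]]

      val statuses = Array.fill(100)(Loc("hostA", Array[Byte](1, 2, 3)))
      val bytes = serialize(statuses)
      println(s"serialized ${statuses.length} statuses into ${bytes.length} bytes")
      assert(deserialize(bytes).length == statuses.length)
    }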

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/9159d2d0/core/src/main/scala/org/apache/spark/SparkEnv.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 29968c2..aaab717 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -187,10 +187,14 @@ object SparkEnv extends Logging {
 
     // Have to assign trackerActor after initialization as MapOutputTrackerActor
     // requires the MapOutputTracker itself
-    val mapOutputTracker = new MapOutputTracker()
+    val mapOutputTracker =  if (isDriver) {
+      new MapOutputTrackerMaster()
+    } else {
+      new MapOutputTracker()
+    }
     mapOutputTracker.trackerActor = registerOrLookup(
       "MapOutputTracker",
-      new MapOutputTrackerActor(mapOutputTracker))
+      new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]))
 
     val shuffleFetcher = instantiateClass[ShuffleFetcher](
       "spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher")

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/9159d2d0/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index d84f596..e58ff37 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -52,13 +52,14 @@ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedH
 private[spark]
 class DAGScheduler(
     taskSched: TaskScheduler,
-    mapOutputTracker: MapOutputTracker,
+    mapOutputTracker: MapOutputTrackerMaster,
     blockManagerMaster: BlockManagerMaster,
     env: SparkEnv)
   extends Logging {
 
   def this(taskSched: TaskScheduler) {
-    this(taskSched, SparkEnv.get.mapOutputTracker, SparkEnv.get.blockManager.master, SparkEnv.get)
+    this(taskSched, SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster],
+      SparkEnv.get.blockManager.master, SparkEnv.get)
   }
   taskSched.setDAGScheduler(this)
 

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/9159d2d0/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index 6013320..b7eb268 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -48,15 +48,15 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
 
   test("master start and stop") {
     val actorSystem = ActorSystem("test")
-    val tracker = new MapOutputTracker()
-    tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
+    val tracker = new MapOutputTrackerMaster()
+    tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))
     tracker.stop()
   }
 
   test("master register and fetch") {
     val actorSystem = ActorSystem("test")
-    val tracker = new MapOutputTracker()
-    tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
+    val tracker = new MapOutputTrackerMaster()
+    tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))
     tracker.registerShuffle(10, 2)
     val compressedSize1000 = MapOutputTracker.compressSize(1000L)
     val compressedSize10000 = MapOutputTracker.compressSize(10000L)
@@ -74,19 +74,17 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
 
   test("master register and unregister and fetch") {
     val actorSystem = ActorSystem("test")
-    val tracker = new MapOutputTracker()
-    tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
+    val tracker = new MapOutputTrackerMaster()
+    tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))
     tracker.registerShuffle(10, 2)
     val compressedSize1000 = MapOutputTracker.compressSize(1000L)
     val compressedSize10000 = MapOutputTracker.compressSize(10000L)
-    val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
-    val size10000 = MapOutputTracker.decompressSize(compressedSize10000)
     tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0),
         Array(compressedSize1000, compressedSize1000, compressedSize1000)))
     tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0),
         Array(compressedSize10000, compressedSize1000, compressedSize1000)))
 
-    // As if we had two simulatenous fetch failures
+    // As if we had two simultaneous fetch failures
     tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
     tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
 
@@ -102,9 +100,9 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
     System.setProperty("spark.driver.port", boundPort.toString)    // Will be cleared by LocalSparkContext
     System.setProperty("spark.hostPort", hostname + ":" + boundPort)
 
-    val masterTracker = new MapOutputTracker()
+    val masterTracker = new MapOutputTrackerMaster()
     masterTracker.trackerActor = actorSystem.actorOf(
-        Props(new MapOutputTrackerActor(masterTracker)), "MapOutputTracker")
+        Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")
 
     val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0)
     val slaveTracker = new MapOutputTracker()

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/9159d2d0/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 2a2f828..00f2fdd 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -23,7 +23,7 @@ import org.scalatest.FunSuite
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.LocalSparkContext
-import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTrackerMaster
 import org.apache.spark.SparkContext
 import org.apache.spark.Partition
 import org.apache.spark.TaskContext
@@ -64,7 +64,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
     override def defaultParallelism() = 2
   }
 
-  var mapOutputTracker: MapOutputTracker = null
+  var mapOutputTracker: MapOutputTrackerMaster = null
   var scheduler: DAGScheduler = null
 
   /**
@@ -99,7 +99,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
     taskSets.clear()
     cacheLocations.clear()
     results.clear()
-    mapOutputTracker = new MapOutputTracker()
+    mapOutputTracker = new MapOutputTrackerMaster()
     scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null) {
       override def runLocally(job: ActiveJob) {
         // don't bother with the thread while unit testing

