spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From ma...@apache.org
Subject [09/69] [abbrv] [partial] Initial work to rename package to org.apache.spark
Date Sun, 01 Sep 2013 22:00:15 GMT
http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/46eecd11/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
new file mode 100644
index 0000000..94df282
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -0,0 +1,421 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import scala.collection.mutable.{Map, HashMap}
+
+import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.LocalSparkContext
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.RDD
+import org.apache.spark.SparkContext
+import org.apache.spark.Partition
+import org.apache.spark.TaskContext
+import org.apache.spark.{Dependency, ShuffleDependency, OneToOneDependency}
+import org.apache.spark.{FetchFailed, Success, TaskEndReason}
+import org.apache.spark.storage.{BlockManagerId, BlockManagerMaster}
+
+import org.apache.spark.scheduler.cluster.Pool
+import org.apache.spark.scheduler.cluster.SchedulingMode
+import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode
+
+/**
+ * Tests for DAGScheduler. These tests directly call the event processing functions in DAGScheduler
+ * rather than spawning an event loop thread as happens in the real code. They use stubs
+ * in place of two classes that DAGScheduler interacts with: TaskScheduler (to which TaskSets are
+ * submitted) and BlockManagerMaster (from which cache locations are retrieved and to which dead
+ * host notifications are sent). In addition, tests may check for side effects on a non-mocked
+ * MapOutputTracker instance.
+ *
+ * Tests primarily consist of running DAGScheduler#processEvent and
+ * DAGScheduler#submitWaitingStages (via test utility functions like runEvent or respondToTaskSet)
+ * and capturing the resulting TaskSets from the mock TaskScheduler.
+ */
+class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
+
+  /** Set of TaskSets the DAGScheduler has requested executed. */
+  val taskSets = scala.collection.mutable.Buffer[TaskSet]()
+  val taskScheduler = new TaskScheduler() {
+    override def rootPool: Pool = null
+    override def schedulingMode: SchedulingMode = SchedulingMode.NONE
+    override def start() = {}
+    override def stop() = {}
+    override def submitTasks(taskSet: TaskSet) = {
+      // normally done by TaskSetManager
+      taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch)
+      taskSets += taskSet
+    }
+    override def setListener(listener: TaskSchedulerListener) = {}
+    override def defaultParallelism() = 2
+  }
+
+  var mapOutputTracker: MapOutputTracker = null
+  var scheduler: DAGScheduler = null
+
+  /**
+   * Set of cache locations to return from our mock BlockManagerMaster.
+   * Keys are (rdd ID, partition ID). Anything not present will return an empty
+   * list of cache locations silently.
+   */
+  val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]]
+  // stub out BlockManagerMaster.getLocations to use our cacheLocations
+  val blockManagerMaster = new BlockManagerMaster(null) {
+    // A name whose first "_"-separated piece is "rdd" is read as rdd_<rdd ID>_<partition ID>.
+    override def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = {
+      blockIds.map { blockName =>
+        val parts = blockName.split("_")
+        parts(0) match {
+          case "rdd" =>
+            cacheLocations.getOrElse((parts(1).toInt, parts(2).toInt), Seq())
+          case _ => Seq()
+        }
+      }.toSeq
+    }
+    override def removeExecutor(execId: String) {
+      // don't need to propagate to the driver, which we don't have
+    }
+  }
+
+  /** The list of results that DAGScheduler has collected. */
+  val results = new HashMap[Int, Any]()
+  var failure: Exception = _
+  val listener = new JobListener() {
+    override def taskSucceeded(index: Int, result: Any) = results.put(index, result)
+    override def jobFailed(exception: Exception) = { failure = exception }
+  }
+
+  before {
+    sc = new SparkContext("local", "DAGSchedulerSuite")
+    // Reset everything the stubs and the listener record so no test sees another's leftovers.
+    taskSets.clear()
+    cacheLocations.clear()
+    results.clear()
+    failure = null
+    mapOutputTracker = new MapOutputTracker()
+    scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null) {
+      // don't bother with the thread while unit testing
+      override def runLocally(job: ActiveJob) { runLocallyWithinThread(job) }
+    }
+  }
+
+  after {
+    scheduler.stop()
+  }
+
+  /**
+   * Type of RDD we use for testing. Note that we should never call the real RDD compute methods.
+   * This is a pair RDD type so it can always be used in ShuffleDependencies.
+   */
+  type MyRDD = RDD[(Int, Int)]
+
+  /**
+   * Create an RDD for passing to DAGScheduler. These RDDs will use the dependencies and
+   * preferredLocations (if any) that are passed to them. They are deliberately not executable
+   * so we can test that DAGScheduler does not try to execute RDDs locally.
+   */
+  private def makeRdd(
+        numPartitions: Int,
+        dependencies: List[Dependency[_]],
+        locations: Seq[Seq[String]] = Nil
+      ): MyRDD = {
+    // Anonymous stub RDD: it describes partitions and locality only, never real computation.
+    new MyRDD(sc, dependencies) {
+      // Reaching compute means something tried to execute this RDD, which tests must not do.
+      override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
+        throw new RuntimeException("should not be reached")
+      // One bare Partition per index in [0, numPartitions).
+      override def getPartitions = (0 until numPartitions).map(i => new Partition {
+        override def index = i
+      }).toArray
+      // A partition with no entry in `locations` has no preferred location.
+      override def getPreferredLocations(split: Partition): Seq[String] =
+        locations.lift(split.index).getOrElse(Nil)
+      override def toString: String = "DAGSchedulerSuiteRDD " + id
+    }
+  }
+
+  /**
+   * Process the supplied event as if it were the top of the DAGScheduler event queue, expecting
+   * the scheduler not to exit.
+   *
+   * After processing the event, submit waiting stages as is done on most iterations of the
+   * DAGScheduler event loop.
+   */
+  private def runEvent(event: DAGSchedulerEvent) {
+    assert(!scheduler.processEvent(event))
+    scheduler.submitWaitingStages()
+  }
+
+  /**
+   * When we submit dummy Jobs, this is the compute function we supply. Except in a local test
+   * below, we do not expect this function to ever be executed; instead, we will return results
+   * directly through CompletionEvents.
+   */
+  private val jobComputeFunc = (context: TaskContext, it: Iterator[(_)]) =>
+     it.next.asInstanceOf[Tuple2[_, _]]._1
+
+  /**
+   * Send the given CompletionEvent messages for the tasks in the TaskSet: the i-th entry of
+   * `results` is reported for the i-th task, so there may not be more results than tasks.
+   */
+  private def complete(taskSet: TaskSet, results: Seq[(TaskEndReason, Any)]) {
+    assert(taskSet.tasks.size >= results.size)
+    for (((reason, value), i) <- results.zipWithIndex)
+      runEvent(CompletionEvent(taskSet.tasks(i), reason, value, Map[Long, Any](), null, null))
+  }
+
+  /** Sends the rdd to the scheduler for scheduling. */
+  private def submit(
+      rdd: RDD[_],
+      partitions: Array[Int],
+      func: (TaskContext, Iterator[_]) => _ = jobComputeFunc,
+      allowLocal: Boolean = false,
+      listener: JobListener = listener) {
+    runEvent(JobSubmitted(rdd, func, partitions, allowLocal, null, listener))
+  }
+
+  /** Sends TaskSetFailed to the scheduler. */
+  private def failed(taskSet: TaskSet, message: String) {
+    runEvent(TaskSetFailed(taskSet, message))
+  }
+
+  test("zero split job") {
+    val rdd = makeRdd(0, Nil)
+    var numResults = 0
+    val fakeListener = new JobListener() {
+      override def taskSucceeded(partition: Int, value: Any) = numResults += 1
+      override def jobFailed(exception: Exception) = throw exception
+    }
+    submit(rdd, Array(), listener = fakeListener)
+    assert(numResults === 0)
+  }
+
+  test("run trivial job") {
+    val rdd = makeRdd(1, Nil)
+    submit(rdd, Array(0))
+    complete(taskSets(0), List((Success, 42)))
+    assert(results === Map(0 -> 42))
+  }
+
+  test("local job") {
+    val rdd = new MyRDD(sc, Nil) {
+      override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
+        Array(42 -> 0).iterator
+      override def getPartitions = Array( new Partition { override def index = 0 } )
+      override def getPreferredLocations(split: Partition) = Nil
+      override def toString = "DAGSchedulerSuite Local RDD"
+    }
+    submit(rdd, Array(0), allowLocal = true)  // runs compute() in-process; no complete() needed
+    assert(results === Map(0 -> 42))
+  }
+
+  test("run trivial job w/ dependency") {
+    val baseRdd = makeRdd(1, Nil)
+    val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd)))
+    submit(finalRdd, Array(0))
+    complete(taskSets(0), Seq((Success, 42)))
+    assert(results === Map(0 -> 42))
+  }
+
+  test("cache location preferences w/ dependency") {
+    val baseRdd = makeRdd(1, Nil)
+    val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd)))
+    cacheLocations(baseRdd.id -> 0) =
+      Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))
+    submit(finalRdd, Array(0))
+    val taskSet = taskSets(0)
+    assertLocations(taskSet, Seq(Seq("hostA", "hostB")))
+    complete(taskSet, Seq((Success, 42)))
+    assert(results === Map(0 -> 42))
+  }
+
+  test("trivial job failure") {
+    submit(makeRdd(1, Nil), Array(0))
+    failed(taskSets(0), "some failure")
+    assert(failure.getMessage === "Job failed: some failure")
+  }
+
+  test("run trivial shuffle") {
+    val shuffleMapRdd = makeRdd(2, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+    val shuffleId = shuffleDep.shuffleId
+    val reduceRdd = makeRdd(1, List(shuffleDep))
+    submit(reduceRdd, Array(0))
+    complete(taskSets(0), Seq(
+        (Success, makeMapStatus("hostA", 1)),
+        (Success, makeMapStatus("hostB", 1))))
+    assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+           Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
+    complete(taskSets(1), Seq((Success, 42)))
+    assert(results === Map(0 -> 42))
+  }
+
+  test("run trivial shuffle with fetch failure") {
+    val shuffleMapRdd = makeRdd(2, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+    val shuffleId = shuffleDep.shuffleId
+    val reduceRdd = makeRdd(2, List(shuffleDep))
+    submit(reduceRdd, Array(0, 1))
+    complete(taskSets(0), Seq(
+        (Success, makeMapStatus("hostA", 1)),
+        (Success, makeMapStatus("hostB", 1))))
+    // the 2nd ResultTask failed
+    complete(taskSets(1), Seq(
+        (Success, 42),
+        (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), null)))
+    // this will get called
+    // blockManagerMaster.removeExecutor("exec-hostA")
+    // ask the scheduler to try it again
+    scheduler.resubmitFailedStages()
+    // have the 2nd attempt pass
+    complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1))))
+    // we can see both result blocks now
+    assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB"))
+    complete(taskSets(3), Seq((Success, 43)))
+    assert(results === Map(0 -> 42, 1 -> 43))
+  }
+
+  test("ignore late map task completions") {
+    val shuffleMapRdd = makeRdd(2, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+    val shuffleId = shuffleDep.shuffleId
+    val reduceRdd = makeRdd(2, List(shuffleDep))
+    submit(reduceRdd, Array(0, 1))
+    // pretend we were told hostA went away
+    val oldEpoch = mapOutputTracker.getEpoch
+    runEvent(ExecutorLost("exec-hostA"))
+    val newEpoch = mapOutputTracker.getEpoch
+    assert(newEpoch > oldEpoch)
+    val noAccum = Map[Long, Any]()
+    val taskSet = taskSets(0)
+    // should be ignored for being too old
+    runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum, null, null))
+    // should work because it's a non-failed host
+    runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), noAccum, null, null))
+    // should be ignored for being too old
+    runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum, null, null))
+    // should work because it's a new epoch
+    taskSet.tasks(1).epoch = newEpoch
+    runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum, null, null))
+    assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+           Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
+    complete(taskSets(1), Seq((Success, 42), (Success, 43)))
+    assert(results === Map(0 -> 42, 1 -> 43))
+  }
+
+  test("run trivial shuffle with out-of-band failure and retry") {
+    val shuffleMapRdd = makeRdd(2, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+    val shuffleId = shuffleDep.shuffleId
+    val reduceRdd = makeRdd(1, List(shuffleDep))
+    submit(reduceRdd, Array(0))
+    // blockManagerMaster.removeExecutor("exec-hostA")
+    // pretend we were told hostA went away
+    runEvent(ExecutorLost("exec-hostA"))
+    // DAGScheduler will immediately resubmit the stage after it appears to have no pending tasks
+    // rather than marking it as failed and waiting.
+    complete(taskSets(0), Seq(
+        (Success, makeMapStatus("hostA", 1)),
+        (Success, makeMapStatus("hostB", 1))))
+    // have hostC complete the resubmitted task
+    complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1))))
+    assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+           Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
+    complete(taskSets(2), Seq((Success, 42)))
+    assert(results === Map(0 -> 42))
+  }
+
+  test("recursive shuffle failures") {
+    val shuffleOneRdd = makeRdd(2, Nil)
+    val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
+    val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
+    val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
+    val finalRdd = makeRdd(1, List(shuffleDepTwo))
+    submit(finalRdd, Array(0))
+    // have the first stage complete normally
+    complete(taskSets(0), Seq(
+        (Success, makeMapStatus("hostA", 2)),
+        (Success, makeMapStatus("hostB", 2))))
+    // have the second stage complete normally
+    complete(taskSets(1), Seq(
+        (Success, makeMapStatus("hostA", 1)),
+        (Success, makeMapStatus("hostC", 1))))
+    // fail the third stage because hostA went down
+    complete(taskSets(2), Seq(
+        (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)))
+    // TODO assert this:
+    // blockManagerMaster.removeExecutor("exec-hostA")
+    // have DAGScheduler try again
+    scheduler.resubmitFailedStages()
+    complete(taskSets(3), Seq((Success, makeMapStatus("hostA", 2))))
+    complete(taskSets(4), Seq((Success, makeMapStatus("hostA", 1))))
+    complete(taskSets(5), Seq((Success, 42)))
+    assert(results === Map(0 -> 42))
+  }
+
+  test("cached post-shuffle") {
+    val shuffleOneRdd = makeRdd(2, Nil)
+    val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
+    val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
+    val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
+    val finalRdd = makeRdd(1, List(shuffleDepTwo))
+    submit(finalRdd, Array(0))
+    cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD"))
+    cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC"))
+    // complete stage 2
+    complete(taskSets(0), Seq(
+        (Success, makeMapStatus("hostA", 2)),
+        (Success, makeMapStatus("hostB", 2))))
+    // complete stage 1
+    complete(taskSets(1), Seq(
+        (Success, makeMapStatus("hostA", 1)),
+        (Success, makeMapStatus("hostB", 1))))
+    // pretend stage 0 failed because hostA went down
+    complete(taskSets(2), Seq(
+        (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)))
+    // TODO assert this:
+    // blockManagerMaster.removeExecutor("exec-hostA")
+    // DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun.
+    scheduler.resubmitFailedStages()
+    assertLocations(taskSets(3), Seq(Seq("hostD")))
+    // allow hostD to recover
+    complete(taskSets(3), Seq((Success, makeMapStatus("hostD", 1))))
+    complete(taskSets(4), Seq((Success, 42)))
+    assert(results === Map(0 -> 42))
+  }
+
+  /**
+   * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations.
+   * Note that this checks only the host and not the executor ID.
+   */
+  private def assertLocations(taskSet: TaskSet, hosts: Seq[Seq[String]]) {
+    assert(hosts.size === taskSet.tasks.size)
+    for ((taskLocs, expectedLocs) <- taskSet.tasks.map(_.preferredLocations).zip(hosts)) {
+      assert(taskLocs.map(_.host) === expectedLocs)
+    }
+  }
+
+  private def makeMapStatus(host: String, reduces: Int): MapStatus =
+    new MapStatus(makeBlockManagerId(host), Array.fill[Byte](reduces)(2))
+
+  private def makeBlockManagerId(host: String): BlockManagerId =
+    BlockManagerId("exec-" + host, host, 12345, 0)
+
+}

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/46eecd11/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
new file mode 100644
index 0000000..f5b3e97
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.util.Properties
+import java.util.concurrent.LinkedBlockingQueue
+import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+import scala.collection.mutable
+import org.apache.spark._
+import org.apache.spark.SparkContext._
+
+
+class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
+
+  test("inner method") {
+    sc = new SparkContext("local", "joblogger")
+    val joblogger = new JobLogger {
+      def createLogWriterTest(jobID: Int) = createLogWriter(jobID)
+      def closeLogWriterTest(jobID: Int) = closeLogWriter(jobID)
+      def getRddNameTest(rdd: RDD[_]) = getRddName(rdd)
+      def buildJobDepTest(jobID: Int, stage: Stage) = buildJobDep(jobID, stage)
+    }
+    type MyRDD = RDD[(Int, Int)]
+    def makeRdd(
+        numPartitions: Int,
+        dependencies: List[Dependency[_]]
+      ): MyRDD = {
+      // Stub RDD that only describes its partitions; computing it would be a test bug.
+      new MyRDD(sc, dependencies) {
+        override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
+          throw new RuntimeException("should not be reached")
+        override def getPartitions = (0 until numPartitions).map(i => new Partition {
+          override def index = i
+        }).toArray
+      }
+    }
+    val jobID = 5
+    val parentRdd = makeRdd(4, Nil)
+    val shuffleDep = new ShuffleDependency(parentRdd, null)
+    val rootRdd = makeRdd(4, List(shuffleDep))
+    val shuffleMapStage = new Stage(1, parentRdd, Some(shuffleDep), Nil, jobID, None)
+    val rootStage = new Stage(0, rootRdd, None, List(shuffleMapStage), jobID, None)
+
+    joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStage, 4, null))
+    joblogger.getRddNameTest(parentRdd) should be (parentRdd.getClass.getName)
+    parentRdd.setName("MyRDD")
+    joblogger.getRddNameTest(parentRdd) should be ("MyRDD")
+    joblogger.createLogWriterTest(jobID)
+    joblogger.getJobIDtoPrintWriter.size should be (1)
+    joblogger.buildJobDepTest(jobID, rootStage)
+    joblogger.getJobIDToStages.get(jobID).map(_.size) should be (Some(2))
+    joblogger.getStageIDToJobID.get(0) should be (Some(jobID))
+    joblogger.getStageIDToJobID.get(1) should be (Some(jobID))
+    joblogger.closeLogWriterTest(jobID)
+    joblogger.getStageIDToJobID.size should be (0)
+    joblogger.getJobIDToStages.size should be (0)
+    joblogger.getJobIDtoPrintWriter.size should be (0)
+  }
+  
+  test("inner variables") {
+    sc = new SparkContext("local[4]", "joblogger")
+    val joblogger = new JobLogger {
+      override protected def closeLogWriter(jobID: Int) = 
+        getJobIDtoPrintWriter.get(jobID).foreach { fileWriter => 
+          fileWriter.close()
+        }
+    }
+    sc.addSparkListener(joblogger)
+    val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) }
+    rdd.reduceByKey(_+_).collect()
+    
+    joblogger.getLogDir should be ("/tmp/spark")
+    joblogger.getJobIDtoPrintWriter.size should be (1)
+    joblogger.getStageIDToJobID.size should be (2)
+    joblogger.getStageIDToJobID.get(0) should be (Some(0))
+    joblogger.getStageIDToJobID.get(1) should be (Some(0))
+    joblogger.getJobIDToStages.size should be (1)
+  }
+  
+  
+  test("interface functions") {
+    sc = new SparkContext("local[4]", "joblogger")
+    val joblogger = new JobLogger {
+      var onTaskEndCount = 0
+      var onJobEndCount = 0 
+      var onJobStartCount = 0
+      var onStageCompletedCount = 0
+      var onStageSubmittedCount = 0
+      override def onTaskEnd(taskEnd: SparkListenerTaskEnd)  = onTaskEndCount += 1
+      override def onJobEnd(jobEnd: SparkListenerJobEnd) = onJobEndCount += 1
+      override def onJobStart(jobStart: SparkListenerJobStart) = onJobStartCount += 1
+      override def onStageCompleted(stageCompleted: StageCompleted) = onStageCompletedCount += 1
+      override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = onStageSubmittedCount += 1
+    }
+    sc.addSparkListener(joblogger)
+    val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) }
+    rdd.reduceByKey(_+_).collect()
+    
+    joblogger.onJobStartCount should be (1)
+    joblogger.onJobEndCount should be (1)
+    joblogger.onTaskEndCount should be (8)
+    joblogger.onStageSubmittedCount should be (2)
+    joblogger.onStageCompletedCount should be (2)
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/46eecd11/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
new file mode 100644
index 0000000..aac7c20
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import org.scalatest.FunSuite
+import org.apache.spark.{SparkContext, LocalSparkContext}
+import scala.collection.mutable
+import org.scalatest.matchers.ShouldMatchers
+import org.apache.spark.SparkContext._
+
+/**
+ * Tests the stage and task metrics delivered to a SparkListener by a local SparkContext.
+ */
+
+class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
+
+  test("local metrics") {
+    sc = new SparkContext("local[4]", "test")
+    val listener = new SaveStageInfo
+    sc.addSparkListener(listener)
+    sc.addSparkListener(new StatsReportListener)
+    // just to make sure some of the tasks take a noticeable amount of time
+    // NOTE(review): every value passed to w below is >= 1, so this sleep never runs -- confirm
+    val w = { i: Int =>
+      if (i == 0) Thread.sleep(100)
+      i
+    }
+
+    val d = sc.parallelize(1 to 1e4.toInt, 64).map{i => w(i)}
+    d.count
+    listener.stageInfos.size should be (1)
+
+    val d2 = d.map{i => w(i) -> i * 2}.setName("shuffle input 1")
+
+    val d3 = d.map{i => w(i) -> (0 to (i % 5))}.setName("shuffle input 2")
+
+    val d4 = d2.cogroup(d3, 64).map{case(k,(v1,v2)) => w(k) -> (v1.size, v2.size)}
+    d4.setName("A Cogroup")
+
+    d4.collectAsMap
+
+    listener.stageInfos.size should be (4)
+    listener.stageInfos.foreach { stageInfo =>
+      // small test, so some tasks might take less than 1 millisecond, but the average must be > 0
+      checkNonZeroAvg(stageInfo.taskInfos.map{_._1.duration}, stageInfo + " duration")
+      checkNonZeroAvg(stageInfo.taskInfos.map{_._2.executorRunTime.toLong}, stageInfo + " executorRunTime")
+      checkNonZeroAvg(stageInfo.taskInfos.map{_._2.executorDeserializeTime.toLong}, stageInfo + " executorDeserializeTime")
+      if (stageInfo.stage.rdd.name == d4.name) {
+        checkNonZeroAvg(stageInfo.taskInfos.map{_._2.shuffleReadMetrics.get.fetchWaitTime}, stageInfo + " fetchWaitTime")
+      }
+
+      stageInfo.taskInfos.foreach { case (taskInfo, taskMetrics) =>
+        taskMetrics.resultSize should be > (0L)
+        if (isStage(stageInfo, Set(d2.name, d3.name), Set(d4.name))) {
+          taskMetrics.shuffleWriteMetrics should be ('defined)
+          taskMetrics.shuffleWriteMetrics.get.shuffleBytesWritten should be > (0L)
+        }
+        if (stageInfo.stage.rdd.name == d4.name) {
+          taskMetrics.shuffleReadMetrics should be ('defined)
+          val sm = taskMetrics.shuffleReadMetrics.get
+          sm.totalBlocksFetched should be > (0)
+          sm.localBlocksFetched should be > (0)
+          sm.remoteBlocksFetched should be (0)
+          sm.remoteBytesRead should be (0L)
+          sm.remoteFetchTime should be (0L)
+        }
+      }
+    }
+  }
+
+  def checkNonZeroAvg(m: Traversable[Long], msg: String) {
+    assert(m.sum / m.size.toDouble > 0.0, msg)
+  }
+
+  def isStage(stageInfo: StageInfo, rddNames: Set[String], excludedNames: Set[String]) = {
+    val names = Set(stageInfo.stage.rdd.name) ++ stageInfo.stage.rdd.dependencies.map{_.rdd.name}
+    !names.intersect(rddNames).isEmpty && names.intersect(excludedNames).isEmpty
+  }
+
+  class SaveStageInfo extends SparkListener {
+    val stageInfos = mutable.Buffer[StageInfo]()
+    override def onStageCompleted(stage: StageCompleted) {
+      stageInfos += stage.stageInfo
+    }
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/46eecd11/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
new file mode 100644
index 0000000..0347cc0
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
+import org.apache.spark.TaskContext
+import org.apache.spark.RDD
+import org.apache.spark.SparkContext
+import org.apache.spark.Partition
+import org.apache.spark.LocalSparkContext
+
+class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
+
+  test("Calls executeOnCompleteCallbacks after failure") {
+    var completed = false
+    sc = new SparkContext("local", "test")
+    val rdd = new RDD[String](sc, List()) {
+      override def getPartitions = Array[Partition](StubPartition(0))
+      override def compute(split: Partition, context: TaskContext) = {
+        context.addOnCompleteCallback(() => completed = true)
+        sys.error("failed")
+      }
+    }
+    val func = (c: TaskContext, i: Iterator[String]) => i.next
+    val task = new ResultTask[String, String](0, rdd, func, 0, Seq(), 0)
+    intercept[RuntimeException] {
+      task.run(0)
+    }
+    assert(completed === true)
+  }
+
+  case class StubPartition(index: Int) extends Partition  // bare Partition carrying only its index
+}

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/46eecd11/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala
new file mode 100644
index 0000000..92ad9f0
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala
@@ -0,0 +1,266 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark._
+import org.apache.spark.scheduler._
+import org.apache.spark.scheduler.cluster._
+import scala.collection.mutable.ArrayBuffer
+
+import java.util.Properties
+
+/**
+ * Test double for ClusterTaskSetManager: it hands out placeholder task descriptions and keeps
+ * its own running/finished counters, so pool ordering can be exercised without real tasks.
+ *
+ * @param initPriority value assigned to `priority`
+ * @param initStageId value assigned to `stageId` (also used to build `name`)
+ * @param initNumTasks total number of tasks this manager is willing to hand out
+ * @param clusterScheduler scheduler passed through to ClusterTaskSetManager
+ * @param taskSet task set passed through to ClusterTaskSetManager
+ */
+class FakeTaskSetManager(
+    initPriority: Int,
+    initStageId: Int,
+    initNumTasks: Int,
+    clusterScheduler: ClusterScheduler,
+    taskSet: TaskSet)
+  extends ClusterTaskSetManager(clusterScheduler, taskSet) {
+
+  parent = null
+  weight = 1
+  minShare = 2
+  runningTasks = 0
+  priority = initPriority
+  stageId = initStageId
+  name = "TaskSet_"+stageId
+  override val numTasks = initNumTasks
+  tasksFinished = 0
+
+  /** Adds `taskNum` to the running count and forwards the change to the parent, if any. */
+  override def increaseRunningTasks(taskNum: Int) {
+    runningTasks += taskNum
+    if (parent != null) {
+      parent.increaseRunningTasks(taskNum)
+    }
+  }
+
+  /** Subtracts `taskNum` from the running count and forwards the change to the parent, if any. */
+  override def decreaseRunningTasks(taskNum: Int) {
+    runningTasks -= taskNum
+    if (parent != null) {
+      parent.decreaseRunningTasks(taskNum)
+    }
+  }
+
+  /** No-op: this fake never holds child schedulables. */
+  override def addSchedulable(schedulable: Schedulable) {
+  }
+
+  /** No-op: this fake never holds child schedulables. */
+  override def removeSchedulable(schedulable: Schedulable) {
+  }
+
+  /** Always returns null, since there are no children to look up. */
+  override def getSchedulableByName(name: String): Schedulable = null
+
+  /** No-op: executor loss is ignored by this fake. */
+  override def executorLost(executorId: String, host: String): Unit = {
+  }
+
+  /**
+   * Accepts the offer while fewer than `numTasks` tasks are finished or running, bumping the
+   * running count by one; the locality and CPU arguments are ignored.
+   *
+   * @return a fixed placeholder TaskDescription for `execId`, or None once all tasks are taken
+   */
+  override def resourceOffer(
+      execId: String,
+      host: String,
+      availableCpus: Int,
+      maxLocality: TaskLocality.TaskLocality)
+    : Option[TaskDescription] =
+  {
+    if (tasksFinished + runningTasks < numTasks) {
+      increaseRunningTasks(1)
+      Some(new TaskDescription(0, execId, "task 0:0", 0, null))
+    } else {
+      None
+    }
+  }
+
+  /** Always reports true. */
+  override def checkSpeculatableTasks(): Boolean = true
+
+  /**
+   * Marks one running task as finished; once every task has finished, the manager removes
+   * itself from its parent.
+   */
+  def taskFinished() {
+    decreaseRunningTasks(1)
+    tasksFinished +=1
+    // `parent` starts out null (see the initializer above), so guard it exactly like
+    // increaseRunningTasks/decreaseRunningTasks do instead of risking an NPE.
+    if (tasksFinished == numTasks && parent != null) {
+      parent.removeSchedulable(this)
+    }
+  }
+
+  /** Drops all running tasks from the counters and removes the manager from its parent. */
+  def abort() {
+    decreaseRunningTasks(runningTasks)
+    // Same null guard as in taskFinished(): the manager may never have been added to a pool.
+    if (parent != null) {
+      parent.removeSchedulable(this)
+    }
+  }
+}
+
+class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging {
+
+  /** Builds a FakeTaskSetManager with the given priority, stage id and task count. */
+  def createDummyTaskSetManager(
+      priority: Int,
+      stage: Int,
+      numTasks: Int,
+      cs: ClusterScheduler,
+      taskSet: TaskSet): FakeTaskSetManager = {
+    val manager = new FakeTaskSetManager(priority, stage, numTasks, cs, taskSet)
+    manager
+  }
+
+  /**
+   * Offers a single slot to the task set managers under `rootPool`, in the order produced by
+   * `getSortedTaskSetQueue()`.
+   *
+   * @param rootPool pool whose sorted queue is offered the slot
+   * @return the stageId of the first manager that accepts the offer, or -1 if none does
+   */
+  def resourceOffer(rootPool: Pool): Int = {
+    val taskSetQueue = rootPool.getSortedTaskSetQueue()
+    /* Just for Test*/
+    for (manager <- taskSetQueue) {
+       logInfo("parentName:%s, parent running tasks:%d, name:%s,runningTasks:%d".format(manager.parent.name, manager.parent.runningTasks, manager.name, manager.runningTasks))
+    }
+    // find() stops at the first manager that returns a task, so managers further down the
+    // queue are never offered the slot. This replaces a non-local `return` from inside a for
+    // body, which Scala implements by throwing a control-flow exception.
+    val accepted = taskSetQueue.find { manager =>
+      manager.resourceOffer("execId_1", "hostname_1", 1, TaskLocality.ANY).isDefined
+    }
+    accepted.map(_.stageId).getOrElse(-1)
+  }
+
+  /** Makes one offer against `rootPool` and asserts which stage id accepted it. */
+  def checkTaskSetId(rootPool: Pool, expectedTaskSetId: Int) {
+    val acceptedTaskSetId = resourceOffer(rootPool)
+    assert(acceptedTaskSetId === expectedTaskSetId)
+  }
+
+  // Offers slots to three task sets queued in a FIFO pool and checks which stage takes each one.
+  test("FIFO Scheduler Test") {
+    sc = new SparkContext("local", "ClusterSchedulerSuite")
+    val clusterScheduler = new ClusterScheduler(sc)
+    // One placeholder task is enough: the fake managers build their own task descriptions.
+    val taskSet = new TaskSet(Array[Task[_]](new FakeTask(0)), 0, 0, 0, null)
+
+    val rootPool = new Pool("", SchedulingMode.FIFO, 0, 0)
+    val schedulableBuilder = new FIFOSchedulableBuilder(rootPool)
+    schedulableBuilder.buildPools()
+
+    // Three managers with stage ids 0..2 and two tasks each, registered in stage order.
+    val managers = (0 to 2).map { stage =>
+      createDummyTaskSetManager(0, stage, 2, clusterScheduler, taskSet)
+    }
+    managers.foreach { manager => schedulableBuilder.addTaskSetManager(manager, null) }
+
+    // Stage 0 takes the first two offers, then stage 1 takes the next two.
+    checkTaskSetId(rootPool, 0)
+    resourceOffer(rootPool)
+    checkTaskSetId(rootPool, 1)
+    resourceOffer(rootPool)
+    // Once stage 1 is aborted, the next offer goes to stage 2.
+    managers(1).abort()
+    checkTaskSetId(rootPool, 2)
+  }
+
+  // Builds pools from fairscheduler.xml, checks their configured minShare/weight, and then
+  // checks the order in which task sets from pools "1" and "2" accept offers.
+  test("Fair Scheduler Test") {
+    sc = new SparkContext("local", "ClusterSchedulerSuite")
+    val clusterScheduler = new ClusterScheduler(sc)
+    // One placeholder task is enough: the fake managers build their own task descriptions.
+    val taskSet = new TaskSet(Array[Task[_]](new FakeTask(0)), 0, 0, 0, null)
+
+    // The allocation file is passed through a JVM-wide system property; remember the old
+    // value so it can be put back and does not leak into tests that run afterwards.
+    val allocationFileKey = "spark.fairscheduler.allocation.file"
+    val previousAllocationFile = System.getProperty(allocationFileKey)
+    val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
+    System.setProperty(allocationFileKey, xmlPath)
+    try {
+      val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0)
+      val schedulableBuilder = new FairSchedulableBuilder(rootPool)
+      schedulableBuilder.buildPools()
+
+      // Pools expected from the allocation file, plus the default pool.
+      assert(rootPool.getSchedulableByName("default") != null)
+      assert(rootPool.getSchedulableByName("1") != null)
+      assert(rootPool.getSchedulableByName("2") != null)
+      assert(rootPool.getSchedulableByName("3") != null)
+      assert(rootPool.getSchedulableByName("1").minShare === 2)
+      assert(rootPool.getSchedulableByName("1").weight === 1)
+      assert(rootPool.getSchedulableByName("2").minShare === 3)
+      assert(rootPool.getSchedulableByName("2").weight === 1)
+      assert(rootPool.getSchedulableByName("3").minShare === 2)
+      assert(rootPool.getSchedulableByName("3").weight === 1)
+
+      // Properties that route a task set manager to pool "1" or pool "2".
+      val properties1 = new Properties()
+      properties1.setProperty("spark.scheduler.cluster.fair.pool","1")
+      val properties2 = new Properties()
+      properties2.setProperty("spark.scheduler.cluster.fair.pool","2")
+
+      val taskSetManager10 = createDummyTaskSetManager(1, 0, 1, clusterScheduler, taskSet)
+      val taskSetManager11 = createDummyTaskSetManager(1, 1, 1, clusterScheduler, taskSet)
+      val taskSetManager12 = createDummyTaskSetManager(1, 2, 2, clusterScheduler, taskSet)
+      schedulableBuilder.addTaskSetManager(taskSetManager10, properties1)
+      schedulableBuilder.addTaskSetManager(taskSetManager11, properties1)
+      schedulableBuilder.addTaskSetManager(taskSetManager12, properties1)
+
+      val taskSetManager23 = createDummyTaskSetManager(2, 3, 2, clusterScheduler, taskSet)
+      val taskSetManager24 = createDummyTaskSetManager(2, 4, 2, clusterScheduler, taskSet)
+      schedulableBuilder.addTaskSetManager(taskSetManager23, properties2)
+      schedulableBuilder.addTaskSetManager(taskSetManager24, properties2)
+
+      // Eight offers in a row; each check names the stage id expected to accept that offer.
+      checkTaskSetId(rootPool, 0)
+      checkTaskSetId(rootPool, 3)
+      checkTaskSetId(rootPool, 3)
+      checkTaskSetId(rootPool, 1)
+      checkTaskSetId(rootPool, 4)
+      checkTaskSetId(rootPool, 2)
+      checkTaskSetId(rootPool, 2)
+      checkTaskSetId(rootPool, 4)
+
+      // Finishing or aborting work must be reflected in the owning pool's running count.
+      taskSetManager12.taskFinished()
+      assert(rootPool.getSchedulableByName("1").runningTasks === 3)
+      taskSetManager24.abort()
+      assert(rootPool.getSchedulableByName("2").runningTasks === 2)
+    } finally {
+      if (previousAllocationFile == null) {
+        System.clearProperty(allocationFileKey)
+      } else {
+        System.setProperty(allocationFileKey, previousAllocationFile)
+      }
+    }
+  }
+
+  // Checks the offer order across a two-level hierarchy of FAIR pools.
+  test("Nested Pool Test") {
+    sc = new SparkContext("local", "ClusterSchedulerSuite")
+    val clusterScheduler = new ClusterScheduler(sc)
+    var tasks = ArrayBuffer[Task[_]]()
+    val task = new FakeTask(0)
+    tasks += task
+    val taskSet = new TaskSet(tasks.toArray,0,0,0,null)
+
+    // First level: pools "0" and "1" directly under the root pool.
+    val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0)
+    val pool0 = new Pool("0", SchedulingMode.FAIR, 3, 1)
+    val pool1 = new Pool("1", SchedulingMode.FAIR, 4, 1)
+    rootPool.addSchedulable(pool0)
+    rootPool.addSchedulable(pool1)
+
+    // Second level under pool "0": pools "00" and "01".
+    val pool00 = new Pool("00", SchedulingMode.FAIR, 2, 2)
+    val pool01 = new Pool("01", SchedulingMode.FAIR, 1, 1)
+    pool0.addSchedulable(pool00)
+    pool0.addSchedulable(pool01)
+
+    // Second level under pool "1": pools "10" and "11".
+    val pool10 = new Pool("10", SchedulingMode.FAIR, 2, 2)
+    val pool11 = new Pool("11", SchedulingMode.FAIR, 2, 1)
+    pool1.addSchedulable(pool10)
+    pool1.addSchedulable(pool11)
+
+    // Each leaf pool gets two managers with five tasks each; stage ids 0..7 identify them.
+    val taskSetManager000 = createDummyTaskSetManager(0, 0, 5, clusterScheduler, taskSet)
+    val taskSetManager001 = createDummyTaskSetManager(0, 1, 5, clusterScheduler, taskSet)
+    pool00.addSchedulable(taskSetManager000)
+    pool00.addSchedulable(taskSetManager001)
+
+    val taskSetManager010 = createDummyTaskSetManager(1, 2, 5, clusterScheduler, taskSet)
+    val taskSetManager011 = createDummyTaskSetManager(1, 3, 5, clusterScheduler, taskSet)
+    pool01.addSchedulable(taskSetManager010)
+    pool01.addSchedulable(taskSetManager011)
+
+    val taskSetManager100 = createDummyTaskSetManager(2, 4, 5, clusterScheduler, taskSet)
+    val taskSetManager101 = createDummyTaskSetManager(2, 5, 5, clusterScheduler, taskSet)
+    pool10.addSchedulable(taskSetManager100)
+    pool10.addSchedulable(taskSetManager101)
+
+    val taskSetManager110 = createDummyTaskSetManager(3, 6, 5, clusterScheduler, taskSet)
+    val taskSetManager111 = createDummyTaskSetManager(3, 7, 5, clusterScheduler, taskSet)
+    pool11.addSchedulable(taskSetManager110)
+    pool11.addSchedulable(taskSetManager111)
+
+    // Four offers in a row; each check names the stage id expected to accept that offer.
+    checkTaskSetId(rootPool, 0)
+    checkTaskSetId(rootPool, 4)
+    checkTaskSetId(rootPool, 6)
+    checkTaskSetId(rootPool, 2)
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/46eecd11/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
new file mode 100644
index 0000000..a4f63ba
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
@@ -0,0 +1,273 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable
+
+import org.scalatest.FunSuite
+
+import org.apache.spark._
+import org.apache.spark.scheduler._
+import org.apache.spark.executor.TaskMetrics
+import java.nio.ByteBuffer
+import org.apache.spark.util.FakeClock
+
+/**
+ * ClusterScheduler stand-in that records which tasks were started, how each one ended, and
+ * which TaskSetManagers reported completion. The "live" executors and their hostnames must be
+ * supplied up front: isExecutorAlive and hasExecutorsAliveOnHost answer purely from that map,
+ * and both are required for locality in ClusterTaskSetManager.
+ */
+class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /* execId, host */)
+  extends ClusterScheduler(sc)
+{
+  val startedTasks = new ArrayBuffer[Long]
+  val endedTasks = new mutable.HashMap[Long, TaskEndReason]
+  val finishedManagers = new ArrayBuffer[TaskSetManager]
+
+  // execId -> host for every executor currently considered alive
+  val executors = new mutable.HashMap[String, String] ++ liveExecutors
+
+  // Listener that only records task starts and ends; every other event is ignored.
+  listener = new TaskSchedulerListener {
+    def taskStarted(task: Task[_], taskInfo: TaskInfo) {
+      startedTasks += taskInfo.index
+    }
+
+    def taskEnded(
+        task: Task[_],
+        reason: TaskEndReason,
+        result: Any,
+        accumUpdates: mutable.Map[Long, Any],
+        taskInfo: TaskInfo,
+        taskMetrics: TaskMetrics)
+    {
+      endedTasks.update(taskInfo.index, reason)
+    }
+
+    def executorGained(execId: String, host: String) {}
+
+    def executorLost(execId: String) {}
+
+    def taskSetFailed(taskSet: TaskSet, reason: String) {}
+  }
+
+  /** Forgets `execId`, so it is no longer reported as alive. */
+  def removeExecutor(execId: String): Unit = {
+    executors.remove(execId)
+  }
+
+  /** Records `manager` as finished. */
+  override def taskSetFinished(manager: TaskSetManager): Unit = {
+    finishedManagers += manager
+  }
+
+  /** True while `execId` is present in the executor map. */
+  override def isExecutorAlive(execId: String): Boolean = {
+    executors.isDefinedAt(execId)
+  }
+
+  /** True when at least one known executor is mapped to `host`. */
+  override def hasExecutorsAliveOnHost(host: String): Boolean = {
+    executors.exists { case (_, executorHost) => executorHost == host }
+  }
+}
+
+class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
+  import TaskLocality.{ANY, PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL}
+
+  val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
+
+  test("TaskSet with no preferences") {
+    sc = new SparkContext("local", "test")
+    val sched = new FakeClusterScheduler(sc, ("exec1", "host1"))
+    val manager = new ClusterTaskSetManager(sched, createTaskSet(1))
+
+    // With zero CPUs on offer, nothing can be launched
+    val noCpuOffer = manager.resourceOffer("exec1", "host1", 0, ANY)
+    assert(noCpuOffer === None)
+
+    // The task set has no locality preferences, so even an offer constrained to
+    // process-local should be accepted
+    val launched = manager.resourceOffer("exec1", "host1", 2, PROCESS_LOCAL)
+    assert(launched.isDefined)
+    assert(launched.get.executorId === "exec1")
+    assert(sched.startedTasks.contains(0))
+
+    // A second offer on the same host finds nothing left to launch
+    val secondOffer = manager.resourceOffer("exec1", "host1", 2, PROCESS_LOCAL)
+    assert(secondOffer === None)
+
+    // Report the only task as finished; the manager should then be done as well
+    manager.statusUpdate(0, TaskState.FINISHED, createTaskResult(0))
+    assert(sched.endedTasks(0) === Success)
+    assert(sched.finishedManagers.contains(manager))
+  }
+
+  test("multiple offers with no preferences") {
+    sc = new SparkContext("local", "test")
+    val sched = new FakeClusterScheduler(sc, ("exec1", "host1"))
+    val manager = new ClusterTaskSetManager(sched, createTaskSet(3))
+
+    // Each of the first three offers should yield a task on exec1
+    (0 until 3).foreach { _ =>
+      val launched = manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL)
+      assert(launched.isDefined)
+      assert(launched.get.executorId === "exec1")
+    }
+    assert(sched.startedTasks.toSet === Set(0, 1, 2))
+
+    // A fourth offer on the same host finds nothing left to launch
+    val extraOffer = manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL)
+    assert(extraOffer === None)
+
+    // Finish tasks 0 and 1; the manager must not report completion yet
+    for (taskId <- 0 to 1) {
+      manager.statusUpdate(taskId, TaskState.FINISHED, createTaskResult(taskId))
+    }
+    for (taskId <- 0 to 1) {
+      assert(sched.endedTasks(taskId) === Success)
+    }
+    assert(!sched.finishedManagers.contains(manager))
+
+    // Finishing task 2 completes the whole task set
+    manager.statusUpdate(2, TaskState.FINISHED, createTaskResult(2))
+    assert(sched.endedTasks(2) === Success)
+    assert(sched.finishedManagers.contains(manager))
+  }
+
+  // Four tasks with mixed locality preferences are offered to exec1/host1 only; a FakeClock
+  // is advanced by LOCALITY_WAIT between phases to step through the locality levels.
+  test("basic delay scheduling") {
+    sc = new SparkContext("local", "test")
+    val sched = new FakeClusterScheduler(sc, ("exec1", "host1"), ("exec2", "host2"))
+    val taskSet = createTaskSet(4,
+      Seq(TaskLocation("host1", "exec1")),
+      Seq(TaskLocation("host2", "exec2")),
+      Seq(TaskLocation("host1"), TaskLocation("host2", "exec2")),
+      Seq()   // Last task has no locality prefs
+    )
+    val clock = new FakeClock
+    val manager = new ClusterTaskSetManager(sched, taskSet, clock)
+
+    // First offer host1, exec1: first task should be chosen
+    assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0)
+
+    // Offer host1, exec1 again: the last task, which has no prefs, should be chosen
+    assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 3)
+
+    // Offer host1, exec1 again, at PROCESS_LOCAL level: nothing should get chosen
+    assert(manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) === None)
+
+    clock.advance(LOCALITY_WAIT)
+
+    // Offer host1, exec1 again, at PROCESS_LOCAL level: nothing should get chosen
+    assert(manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) === None)
+
+    // Offer host1, exec1 again, at NODE_LOCAL level: we should choose task 2
+    // (=== rather than ==, like every other check here, so a failure reports both values)
+    assert(manager.resourceOffer("exec1", "host1", 1, NODE_LOCAL).get.index === 2)
+
+    // Offer host1, exec1 again, at NODE_LOCAL level: nothing should get chosen
+    assert(manager.resourceOffer("exec1", "host1", 1, NODE_LOCAL) === None)
+
+    // Offer host1, exec1 again, at ANY level: nothing should get chosen
+    assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None)
+
+    clock.advance(LOCALITY_WAIT)
+
+    // Offer host1, exec1 again, at ANY level: task 1 should get chosen
+    assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 1)
+
+    // Offer host1, exec1 again, at ANY level: nothing should be chosen as we've launched all tasks
+    assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None)
+  }
+
+  test("delay scheduling with fallback") {
+    sc = new SparkContext("local", "test")
+    val sched = new FakeClusterScheduler(sc,
+      ("exec1", "host1"), ("exec2", "host2"), ("exec3", "host3"))
+    val taskSet = createTaskSet(5,
+      Seq(TaskLocation("host1")),
+      Seq(TaskLocation("host2")),
+      Seq(TaskLocation("host2")),
+      Seq(TaskLocation("host3")),
+      Seq(TaskLocation("host2"))
+    )
+    val clock = new FakeClock
+    val manager = new ClusterTaskSetManager(sched, taskSet, clock)
+
+    // Every offer in this test is for one CPU at the ANY locality level
+    def offer(execId: String, host: String) = manager.resourceOffer(execId, host, 1, ANY)
+
+    // host1 is offered first: task 0 prefers it and should be picked
+    assert(offer("exec1", "host1").get.index === 0)
+
+    // Nothing else prefers host1, so a repeated offer is declined for now
+    assert(offer("exec1", "host1") === None)
+
+    clock.advance(LOCALITY_WAIT)
+
+    // After the wait, host1 is allowed to take task 1 (which prefers host2)
+    assert(offer("exec1", "host1").get.index === 1)
+
+    // ... and then task 2 (which also prefers host2)
+    assert(offer("exec1", "host1").get.index === 2)
+
+    // host2 is offered: task 4, which prefers host2, should be picked
+    assert(offer("exec2", "host2").get.index === 4)
+
+    // Having just launched a local task, the manager should hold back the host3 task
+    assert(offer("exec2", "host2") === None)
+
+    clock.advance(LOCALITY_WAIT)
+
+    // After another wait, the host3 task may be launched non-locally
+    assert(offer("exec2", "host2").get.index === 3)
+  }
+
+  test("delay scheduling with failed hosts") {
+    sc = new SparkContext("local", "test")
+    val sched = new FakeClusterScheduler(sc, ("exec1", "host1"), ("exec2", "host2"))
+    val taskSet = createTaskSet(3,
+      Seq(TaskLocation("host1")),
+      Seq(TaskLocation("host2")),
+      Seq(TaskLocation("host3"))
+    )
+    val clock = new FakeClock
+    val manager = new ClusterTaskSetManager(sched, taskSet, clock)
+
+    // Every offer in this test is for one CPU at the ANY locality level
+    def offer(execId: String, host: String) = manager.resourceOffer(execId, host, 1, ANY)
+
+    // host1 is offered first: task 0 prefers it and should be picked
+    assert(offer("exec1", "host1").get.index === 0)
+
+    // host3 is not among the live executors, so task 2 should be handed out right away
+    assert(offer("exec1", "host1").get.index === 2)
+
+    // The remaining task prefers host2, which is alive, so host1 gets nothing
+    assert(offer("exec1", "host1") === None)
+
+    // Take host2 down
+    sched.removeExecutor("exec2")
+    manager.executorLost("exec2", "host2")
+
+    // With its preferred host gone, task 1 should launch on host1 immediately
+    assert(offer("exec1", "host1").get.index === 1)
+
+    // Everything has been launched, so no further offer should be accepted anywhere
+    assert(offer("exec1", "host1") === None)
+    assert(offer("exec2", "host2") === None)
+  }
+
+  /**
+   * Builds a TaskSet of `numTasks` FakeTasks. When `prefLocs` is given it must contain exactly
+   * one sequence of preferred locations per task; when it is omitted the tasks have none.
+   */
+  def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
+    val hasPrefs = prefLocs.size != 0
+    if (hasPrefs && prefLocs.size != numTasks) {
+      throw new IllegalArgumentException("Wrong number of task locations")
+    }
+    val tasks = (0 until numTasks).map { i =>
+      val locations = if (hasPrefs) prefLocs(i) else Nil
+      new FakeTask(i, locations): Task[_]
+    }
+    new TaskSet(tasks.toArray, 0, 0, 0, null)
+  }
+
+  /** Serializes a TaskResult carrying `id`, an empty map and fresh TaskMetrics into a ByteBuffer. */
+  def createTaskResult(id: Int): ByteBuffer = {
+    val result = new TaskResult[Int](id, mutable.Map.empty, new TaskMetrics)
+    val bytes = Utils.serialize(result)
+    ByteBuffer.wrap(bytes)
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/46eecd11/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala
new file mode 100644
index 0000000..2f12aae
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import org.apache.spark.scheduler.{TaskLocation, Task}
+
+/**
+ * Task stub for scheduler tests: `run` always yields 0, and the preferred locations are
+ * whatever was passed to the constructor (none by default).
+ */
+class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId) {
+  override def preferredLocations: Seq[TaskLocation] = {
+    prefLocs
+  }
+
+  override def run(attemptId: Long): Int = {
+    0
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/46eecd11/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala
new file mode 100644
index 0000000..111340a
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.local
+
+import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark._
+import org.apache.spark.scheduler._
+import org.apache.spark.scheduler.cluster._
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.{ConcurrentMap, HashMap}
+import java.util.concurrent.Semaphore
+import java.util.concurrent.CountDownLatch
+import java.util.Properties
+
+/** One-shot gate: `jobWait` blocks its caller until some thread has called `jobFinished`. */
+class Lock() {
+  var finished = false
+
+  /** Blocks on this object's monitor until `finished` has been set. */
+  def jobWait() = this.synchronized {
+    // Re-check the flag after every wake-up; only jobFinished() lets the loop exit.
+    while (!finished) {
+      wait()
+    }
+  }
+
+  /** Sets `finished` and wakes every thread currently blocked in `jobWait`. */
+  def jobFinished() = this.synchronized {
+    finished = true
+    notifyAll()
+  }
+}
+
+/**
+ * Per-job bookkeeping shared between the test thread and the code running inside the jobs,
+ * keyed by the thread index passed to createThread.
+ *
+ * The test thread adds entries while job code reads and updates them, so each map mixes in
+ * SynchronizedMap; a plain mutable.HashMap is not safe for that kind of concurrent access.
+ */
+object TaskThreadInfo {
+  /** Gate each job blocks on until the test releases it. */
+  val threadToLock =
+    new HashMap[Int, Lock] with scala.collection.mutable.SynchronizedMap[Int, Lock]
+
+  /** True while the job's map function is between its start and the release of its gate. */
+  val threadToRunning =
+    new HashMap[Int, Boolean] with scala.collection.mutable.SynchronizedMap[Int, Boolean]
+
+  /** Latch counted down once the job's map function has begun executing. */
+  val threadToStarted =
+    new HashMap[Int, CountDownLatch]
+      with scala.collection.mutable.SynchronizedMap[Int, CountDownLatch]
+}
+
+/*
+ * 1. Each thread runs one job.
+ * 2. Each job contains one stage.
+ * 3. Each stage contains only one task.
+ * 4. Each launched task must be launched in order (using threadToStarted) to make sure
+ *    it gets a CPU core; it then waits until the test manually releases its "Lock",
+ *    at which point the cluster has another free CPU core.
+ * 5. Each pending task must use "sleep" to make sure it has been added to the
+ *    taskSetManager queue, so that it is scheduled later when the cluster has free CPU cores.
+ */
+class LocalSchedulerSuite extends FunSuite with LocalSparkContext {
+
+  /**
+   * Registers bookkeeping entries for `threadIndex` in TaskThreadInfo and starts a thread that
+   * runs one job over the single-element range `threadIndex to threadIndex`.
+   *
+   * The job's map function marks itself as running, counts down its "started" latch, blocks on
+   * its Lock until the test releases it, and finally marks itself as not running. Once
+   * collect() returns, the thread checks the result and releases one permit on `sem`.
+   *
+   * @param threadIndex key used for all TaskThreadInfo entries; also the job's only element
+   * @param poolName value for the "spark.scheduler.cluster.fair.pool" local property, or null
+   *                 to leave the property untouched
+   * @param sc context used to build and run the job
+   * @param sem semaphore released once the job has completed and its result was checked
+   */
+  def createThread(threadIndex: Int, poolName: String, sc: SparkContext, sem: Semaphore) {
+
+    TaskThreadInfo.threadToRunning(threadIndex) = false
+    val nums = sc.parallelize(threadIndex to threadIndex, 1)
+    TaskThreadInfo.threadToLock(threadIndex) = new Lock()
+    TaskThreadInfo.threadToStarted(threadIndex) = new CountDownLatch(1)
+    new Thread {
+      // This block belongs to the anonymous Thread's initializer, so it executes on the thread
+      // that calls createThread (while the Thread object is being constructed), not in run().
+      if (poolName != null) {
+        sc.setLocalProperty("spark.scheduler.cluster.fair.pool", poolName)
+      }
+      override def run() {
+        // `number` is always threadIndex, the only element of `nums`.
+        val ans = nums.map(number => {
+          TaskThreadInfo.threadToRunning(number) = true
+          TaskThreadInfo.threadToStarted(number).countDown()
+          TaskThreadInfo.threadToLock(number).jobWait()
+          TaskThreadInfo.threadToRunning(number) = false
+          number
+        }).collect()
+        // NOTE(review): if this assertion throws, sem.release() below is skipped and the
+        // test's sem.acquire(...) would presumably block forever -- confirm whether intended.
+        assert(ans.toList === List(threadIndex))
+        sem.release()
+      }
+    }.start()
+  }
+
+  // Runs six single-task jobs on a four-core local context and checks that the two queued
+  // jobs are started in submission order as cores become free.
+  test("Local FIFO scheduler end-to-end test") {
+    // The scheduling mode is a JVM-wide system property; remember the old value so it can be
+    // restored and does not leak into tests that run afterwards.
+    val schedulingModeKey = "spark.cluster.schedulingmode"
+    val previousMode = System.getProperty(schedulingModeKey)
+    System.setProperty(schedulingModeKey, "FIFO")
+    try {
+      sc = new SparkContext("local[4]", "test")
+      val sem = new Semaphore(0)
+
+      // Jobs 1-4 are started one at a time, each occupying one of the four cores.
+      createThread(1,null,sc,sem)
+      TaskThreadInfo.threadToStarted(1).await()
+      createThread(2,null,sc,sem)
+      TaskThreadInfo.threadToStarted(2).await()
+      createThread(3,null,sc,sem)
+      TaskThreadInfo.threadToStarted(3).await()
+      createThread(4,null,sc,sem)
+      TaskThreadInfo.threadToStarted(4).await()
+      // Threads 5 and 6 (stages pending) must satisfy the following two points:
+      // 1. the stages (taskSetManagers) of the jobs in threads 5 and 6 must have been added to
+      //    the taskSetManager queue before TaskThreadInfo.threadToLock(1).jobFinished() runs
+      // 2. the stage in thread 5 must take priority over the stage in thread 6
+      // So we simply sleep for 1s after starting each thread.
+      // TODO: any better solution?
+      createThread(5,null,sc,sem)
+      Thread.sleep(1000)
+      createThread(6,null,sc,sem)
+      Thread.sleep(1000)
+
+      assert(TaskThreadInfo.threadToRunning(1) === true)
+      assert(TaskThreadInfo.threadToRunning(2) === true)
+      assert(TaskThreadInfo.threadToRunning(3) === true)
+      assert(TaskThreadInfo.threadToRunning(4) === true)
+      assert(TaskThreadInfo.threadToRunning(5) === false)
+      assert(TaskThreadInfo.threadToRunning(6) === false)
+
+      // Releasing job 1 frees a core, which job 5 should take.
+      TaskThreadInfo.threadToLock(1).jobFinished()
+      TaskThreadInfo.threadToStarted(5).await()
+
+      assert(TaskThreadInfo.threadToRunning(1) === false)
+      assert(TaskThreadInfo.threadToRunning(2) === true)
+      assert(TaskThreadInfo.threadToRunning(3) === true)
+      assert(TaskThreadInfo.threadToRunning(4) === true)
+      assert(TaskThreadInfo.threadToRunning(5) === true)
+      assert(TaskThreadInfo.threadToRunning(6) === false)
+
+      // Releasing job 3 frees another core, which job 6 should take.
+      TaskThreadInfo.threadToLock(3).jobFinished()
+      TaskThreadInfo.threadToStarted(6).await()
+
+      assert(TaskThreadInfo.threadToRunning(1) === false)
+      assert(TaskThreadInfo.threadToRunning(2) === true)
+      assert(TaskThreadInfo.threadToRunning(3) === false)
+      assert(TaskThreadInfo.threadToRunning(4) === true)
+      assert(TaskThreadInfo.threadToRunning(5) === true)
+      assert(TaskThreadInfo.threadToRunning(6) === true)
+
+      // Release the remaining jobs and wait until all six threads have finished.
+      TaskThreadInfo.threadToLock(2).jobFinished()
+      TaskThreadInfo.threadToLock(4).jobFinished()
+      TaskThreadInfo.threadToLock(5).jobFinished()
+      TaskThreadInfo.threadToLock(6).jobFinished()
+      sem.acquire(6)
+    } finally {
+      if (previousMode == null) {
+        System.clearProperty(schedulingModeKey)
+      } else {
+        System.setProperty(schedulingModeKey, previousMode)
+      }
+    }
+  }
+
+  // Runs eleven single-task jobs spread over pools "1", "2" and "3" on an eight-core local
+  // context and checks which pending job is started each time a core becomes free.
+  test("Local fair scheduler end-to-end test") {
+    val schedulingModeKey = "spark.cluster.schedulingmode"
+    val allocationFileKey = "spark.fairscheduler.allocation.file"
+
+    // Puts a system property back to the value it had before this test touched it.
+    def restoreProperty(key: String, previous: String) {
+      if (previous == null) {
+        System.clearProperty(key)
+      } else {
+        System.setProperty(key, previous)
+      }
+    }
+
+    val previousMode = System.getProperty(schedulingModeKey)
+    val previousAllocationFile = System.getProperty(allocationFileKey)
+    val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
+    // Configure the scheduler *before* the SparkContext is created, exactly as the FIFO test
+    // does. Setting these properties afterwards presumably comes too late for the scheduler
+    // that the context builds during construction -- confirm against SparkContext.
+    System.setProperty(schedulingModeKey, "FAIR")
+    System.setProperty(allocationFileKey, xmlPath)
+    try {
+      sc = new SparkContext("local[8]", "LocalSchedulerSuite")
+      val sem = new Semaphore(0)
+
+      // First wave: one job per pool.
+      createThread(10,"1",sc,sem)
+      TaskThreadInfo.threadToStarted(10).await()
+      createThread(20,"2",sc,sem)
+      TaskThreadInfo.threadToStarted(20).await()
+      createThread(30,"3",sc,sem)
+      TaskThreadInfo.threadToStarted(30).await()
+
+      assert(TaskThreadInfo.threadToRunning(10) === true)
+      assert(TaskThreadInfo.threadToRunning(20) === true)
+      assert(TaskThreadInfo.threadToRunning(30) === true)
+
+      // Second wave: a second job per pool.
+      createThread(11,"1",sc,sem)
+      TaskThreadInfo.threadToStarted(11).await()
+      createThread(21,"2",sc,sem)
+      TaskThreadInfo.threadToStarted(21).await()
+      createThread(31,"3",sc,sem)
+      TaskThreadInfo.threadToStarted(31).await()
+
+      assert(TaskThreadInfo.threadToRunning(11) === true)
+      assert(TaskThreadInfo.threadToRunning(21) === true)
+      assert(TaskThreadInfo.threadToRunning(31) === true)
+
+      // Third wave: jobs 12 and 22 take the last two cores, so job 32 has to wait.
+      createThread(12,"1",sc,sem)
+      TaskThreadInfo.threadToStarted(12).await()
+      createThread(22,"2",sc,sem)
+      TaskThreadInfo.threadToStarted(22).await()
+      createThread(32,"3",sc,sem)
+
+      assert(TaskThreadInfo.threadToRunning(12) === true)
+      assert(TaskThreadInfo.threadToRunning(22) === true)
+      assert(TaskThreadInfo.threadToRunning(32) === false)
+
+      // Releasing job 10 frees a core, which job 32 should take.
+      TaskThreadInfo.threadToLock(10).jobFinished()
+      TaskThreadInfo.threadToStarted(32).await()
+
+      assert(TaskThreadInfo.threadToRunning(32) === true)
+
+      // 1. As in the FIFO test, sleep 1s so that the stages of jobs 23 and 33 are added to the
+      //    taskSetManager queue, so that the free CPU core goes to stage 23 once job 11 finishes.
+      // 2. The priorities of 23 and 33 are meaningless here because the fair scheduler is used.
+      createThread(23,"2",sc,sem)
+      createThread(33,"3",sc,sem)
+      Thread.sleep(1000)
+
+      TaskThreadInfo.threadToLock(11).jobFinished()
+      TaskThreadInfo.threadToStarted(23).await()
+
+      assert(TaskThreadInfo.threadToRunning(23) === true)
+      assert(TaskThreadInfo.threadToRunning(33) === false)
+
+      // Releasing job 12 frees a core, which job 33 should take.
+      TaskThreadInfo.threadToLock(12).jobFinished()
+      TaskThreadInfo.threadToStarted(33).await()
+
+      assert(TaskThreadInfo.threadToRunning(33) === true)
+
+      // Release the remaining jobs and wait until all eleven threads have finished.
+      TaskThreadInfo.threadToLock(20).jobFinished()
+      TaskThreadInfo.threadToLock(21).jobFinished()
+      TaskThreadInfo.threadToLock(22).jobFinished()
+      TaskThreadInfo.threadToLock(23).jobFinished()
+      TaskThreadInfo.threadToLock(30).jobFinished()
+      TaskThreadInfo.threadToLock(31).jobFinished()
+      TaskThreadInfo.threadToLock(32).jobFinished()
+      TaskThreadInfo.threadToLock(33).jobFinished()
+
+      sem.acquire(11)
+    } finally {
+      restoreProperty(schedulingModeKey, previousMode)
+      restoreProperty(allocationFileKey, previousAllocationFile)
+    }
+  }
+}


Mime
View raw message