spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From pwend...@apache.org
Subject [02/20] git commit: Deduplicate Local and Cluster schedulers.
Date Wed, 25 Dec 2013 00:35:33 GMT
Deduplicate Local and Cluster schedulers.

The code in LocalScheduler/LocalTaskSetManager was nearly identical
to the code in ClusterScheduler/ClusterTaskSetManager. The redundancy
made updating the schedulers unnecessarily painful and error-
prone. This commit combines the two into a single TaskScheduler/
TaskSetManager.


Project: http://git-wip-us.apache.org/repos/asf/incubator-spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-spark/commit/5e91495f
Tree: http://git-wip-us.apache.org/repos/asf/incubator-spark/tree/5e91495f
Diff: http://git-wip-us.apache.org/repos/asf/incubator-spark/diff/5e91495f

Branch: refs/heads/master
Commit: 5e91495f5c718c837b5a5af2268f6faad00d357f
Parents: dc9ce16
Author: Kay Ousterhout <kayousterhout@gmail.com>
Authored: Wed Oct 30 17:07:24 2013 -0700
Committer: Kay Ousterhout <kayousterhout@gmail.com>
Committed: Wed Oct 30 18:48:34 2013 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/SparkContext.scala   |  34 +-
 .../spark/scheduler/ExecutorLossReason.scala    |  38 +
 .../spark/scheduler/SchedulerBackend.scala      |  37 +
 .../spark/scheduler/TaskResultGetter.scala      | 107 +++
 .../apache/spark/scheduler/TaskScheduler.scala  | 480 ++++++++++++-
 .../apache/spark/scheduler/TaskSetManager.scala | 688 +++++++++++++++++-
 .../apache/spark/scheduler/WorkerOffer.scala    |  24 +
 .../scheduler/cluster/ClusterScheduler.scala    | 486 -------------
 .../cluster/ClusterTaskSetManager.scala         | 703 -------------------
 .../cluster/CoarseGrainedSchedulerBackend.scala |   5 +-
 .../scheduler/cluster/ExecutorLossReason.scala  |  38 -
 .../scheduler/cluster/SchedulerBackend.scala    |  37 -
 .../cluster/SimrSchedulerBackend.scala          |   4 +-
 .../cluster/SparkDeploySchedulerBackend.scala   |   6 +-
 .../scheduler/cluster/TaskResultGetter.scala    | 108 ---
 .../spark/scheduler/cluster/WorkerOffer.scala   |  24 -
 .../mesos/CoarseMesosSchedulerBackend.scala     |   5 +-
 .../cluster/mesos/MesosSchedulerBackend.scala   |   9 +-
 .../spark/scheduler/local/LocalBackend.scala    |  73 ++
 .../spark/scheduler/local/LocalScheduler.scala  | 219 ------
 .../scheduler/local/LocalTaskSetManager.scala   | 191 -----
 .../scala/org/apache/spark/FailureSuite.scala   |  20 +-
 .../spark/scheduler/SparkListenerSuite.scala    |  19 +-
 .../cluster/TaskResultGetterSuite.scala         |   4 +-
 .../scheduler/local/LocalSchedulerSuite.scala   | 227 ------
 .../cluster/YarnClusterScheduler.scala          |  10 +-
 26 files changed, 1472 insertions(+), 2124 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/5e91495f/core/src/main/scala/org/apache/spark/SparkContext.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index ade75e2..1850436 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -56,10 +56,9 @@ import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
 import org.apache.spark.rdd._
 import org.apache.spark.scheduler._
 import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend,
-  SparkDeploySchedulerBackend, ClusterScheduler, SimrSchedulerBackend}
+  SparkDeploySchedulerBackend, SimrSchedulerBackend}
 import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
-import org.apache.spark.scheduler.local.LocalScheduler
-import org.apache.spark.scheduler.StageInfo
+import org.apache.spark.scheduler.local.LocalBackend
 import org.apache.spark.storage.{BlockManagerSource, RDDInfo, StorageStatus, StorageUtils}
 import org.apache.spark.ui.SparkUI
 import org.apache.spark.util.{ClosureCleaner, MetadataCleaner, MetadataCleanerType,
@@ -149,8 +148,6 @@ class SparkContext(
   private[spark] var taskScheduler: TaskScheduler = {
     // Regular expression used for local[N] master format
     val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
-    // Regular expression for local[N, maxRetries], used in tests with failing tasks
-    val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r
     // Regular expression for simulating a Spark cluster of [N, cores, memory] locally
     val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r
     // Regular expression for connecting to Spark deploy clusters
@@ -162,23 +159,26 @@ class SparkContext(
 
     master match {
       case "local" =>
-        new LocalScheduler(1, 0, this)
+        val scheduler = new TaskScheduler(this)
+        val backend = new LocalBackend(scheduler, 1) 
+        scheduler.initialize(backend)
+        scheduler
 
       case LOCAL_N_REGEX(threads) =>
-        new LocalScheduler(threads.toInt, 0, this)
-
-      case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
-        new LocalScheduler(threads.toInt, maxFailures.toInt, this)
+        val scheduler = new TaskScheduler(this)
+        val backend = new LocalBackend(scheduler, threads.toInt) 
+        scheduler.initialize(backend)
+        scheduler
 
       case SPARK_REGEX(sparkUrl) =>
-        val scheduler = new ClusterScheduler(this)
+        val scheduler = new TaskScheduler(this)
         val masterUrls = sparkUrl.split(",").map("spark://" + _)
         val backend = new SparkDeploySchedulerBackend(scheduler, this, masterUrls, appName)
         scheduler.initialize(backend)
         scheduler
 
       case SIMR_REGEX(simrUrl) =>
-        val scheduler = new ClusterScheduler(this)
+        val scheduler = new TaskScheduler(this)
         val backend = new SimrSchedulerBackend(scheduler, this, simrUrl)
         scheduler.initialize(backend)
         scheduler
@@ -192,7 +192,7 @@ class SparkContext(
               memoryPerSlaveInt, SparkContext.executorMemoryRequested))
         }
 
-        val scheduler = new ClusterScheduler(this)
+        val scheduler = new TaskScheduler(this)
         val localCluster = new LocalSparkCluster(
           numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt)
         val masterUrls = localCluster.start()
@@ -207,7 +207,7 @@ class SparkContext(
         val scheduler = try {
           val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler")
           val cons = clazz.getConstructor(classOf[SparkContext])
-          cons.newInstance(this).asInstanceOf[ClusterScheduler]
+          cons.newInstance(this).asInstanceOf[TaskScheduler]
         } catch {
           // TODO: Enumerate the exact reasons why it can fail
           // But irrespective of it, it means we cannot proceed !
@@ -221,7 +221,7 @@ class SparkContext(
 
       case MESOS_REGEX(mesosUrl) =>
         MesosNativeLibrary.load()
-        val scheduler = new ClusterScheduler(this)
+        val scheduler = new TaskScheduler(this)
         val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean
         val backend = if (coarseGrained) {
           new CoarseMesosSchedulerBackend(scheduler, this, mesosUrl, appName)
@@ -593,9 +593,7 @@ class SparkContext(
     }
     addedFiles(key) = System.currentTimeMillis
 
-    // Fetch the file locally in case a job is executed locally.
-    // Jobs that run through LocalScheduler will already fetch the required dependencies,
-    // but jobs run in DAGScheduler.runLocally() will not so we must fetch the files here.
+    // Fetch the file locally in case a job is executed using DAGScheduler.runLocally().
     Utils.fetchFile(path, new File(SparkFiles.getRootDirectory))
 
     logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/5e91495f/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala
new file mode 100644
index 0000000..2bc43a9
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import org.apache.spark.executor.ExecutorExitCode
+
+/**
+ * Represents an explanation for an executor or whole slave failing or exiting.
+ */
+private[spark]
+class ExecutorLossReason(val message: String) {
+  override def toString: String = message
+}
+
+private[spark]
+case class ExecutorExited(val exitCode: Int)
+  extends ExecutorLossReason(ExecutorExitCode.explainExitCode(exitCode)) {
+}
+
+private[spark]
+case class SlaveLost(_message: String = "Slave lost")
+  extends ExecutorLossReason(_message) {
+}

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/5e91495f/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
new file mode 100644
index 0000000..1f0839a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import org.apache.spark.SparkContext
+
+/**
+ * A backend interface for scheduling systems that allows plugging in different ones under
+ * TaskScheduler. We assume a Mesos-like model where the application gets resource offers as
+ * machines become available and can launch tasks on them.
+ */
+private[spark] trait SchedulerBackend {
+  def start(): Unit
+  def stop(): Unit
+  def reviveOffers(): Unit
+  def defaultParallelism(): Int
+
+  def killTask(taskId: Long, executorId: String): Unit = throw new UnsupportedOperationException
+
+  // Memory used by each executor (in megabytes)
+  protected val executorMemory: Int = SparkContext.executorMemoryRequested
+}

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/5e91495f/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
new file mode 100644
index 0000000..5408fa7
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.nio.ByteBuffer
+import java.util.concurrent.{LinkedBlockingDeque, ThreadFactory, ThreadPoolExecutor, TimeUnit}
+
+import org.apache.spark._
+import org.apache.spark.TaskState.TaskState
+import org.apache.spark.serializer.SerializerInstance
+import org.apache.spark.util.Utils
+
+/**
+ * Runs a thread pool that deserializes and remotely fetches (if necessary) task results.
+ */
+private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskScheduler)
+  extends Logging {
+  private val THREADS = System.getProperty("spark.resultGetter.threads", "4").toInt
+  private val getTaskResultExecutor = Utils.newDaemonFixedThreadPool(
+    THREADS, "Result resolver thread")
+
+  protected val serializer = new ThreadLocal[SerializerInstance] {
+    override def initialValue(): SerializerInstance = {
+      return sparkEnv.closureSerializer.newInstance()
+    }
+  }
+
+  def enqueueSuccessfulTask(
+    taskSetManager: TaskSetManager, tid: Long, serializedData: ByteBuffer) {
+    getTaskResultExecutor.execute(new Runnable {
+      override def run() {
+        try {
+          val result = serializer.get().deserialize[TaskResult[_]](serializedData) match {
+            case directResult: DirectTaskResult[_] => directResult
+            case IndirectTaskResult(blockId) =>
+              logDebug("Fetching indirect task result for TID %s".format(tid))
+              scheduler.handleTaskGettingResult(taskSetManager, tid)
+              val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId)
+              if (!serializedTaskResult.isDefined) {
+                /* We won't be able to get the task result if the machine that ran the task failed
+                 * between when the task ended and when we tried to fetch the result, or if the
+                 * block manager had to flush the result. */
+                scheduler.handleFailedTask(
+                  taskSetManager, tid, TaskState.FINISHED, Some(TaskResultLost))
+                return
+              }
+              val deserializedResult = serializer.get().deserialize[DirectTaskResult[_]](
+                serializedTaskResult.get)
+              sparkEnv.blockManager.master.removeBlock(blockId)
+              deserializedResult
+          }
+          result.metrics.resultSize = serializedData.limit()
+          scheduler.handleSuccessfulTask(taskSetManager, tid, result)
+        } catch {
+          case cnf: ClassNotFoundException =>
+            val loader = Thread.currentThread.getContextClassLoader
+            taskSetManager.abort("ClassNotFound with classloader: " + loader)
+          case ex =>
+            taskSetManager.abort("Exception while deserializing and fetching task: %s".format(ex))
+        }
+      }
+    })
+  }
+
+  def enqueueFailedTask(taskSetManager: TaskSetManager, tid: Long, taskState: TaskState,
+    serializedData: ByteBuffer) {
+    var reason: Option[TaskEndReason] = None
+    getTaskResultExecutor.execute(new Runnable {
+      override def run() {
+        try {
+          if (serializedData != null && serializedData.limit() > 0) {
+            reason = Some(serializer.get().deserialize[TaskEndReason](
+              serializedData, getClass.getClassLoader))
+          }
+        } catch {
+          case cnd: ClassNotFoundException =>
+            // Log an error but keep going here -- the task failed, so not catastrophic if we can't
+            // deserialize the reason.
+            val loader = Thread.currentThread.getContextClassLoader
+            logError(
+              "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader)
+          case ex => {}
+        }
+        scheduler.handleFailedTask(taskSetManager, tid, taskState, reason)
+      }
+    })
+  }
+
+  def stop() {
+    getTaskResultExecutor.shutdownNow()
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/5e91495f/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index 10e0478..3f694dd 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -17,39 +17,477 @@
 
 package org.apache.spark.scheduler
 
+import java.nio.ByteBuffer
+import java.util.concurrent.atomic.AtomicLong
+import java.util.{TimerTask, Timer}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+
+import org.apache.spark._
+import org.apache.spark.TaskState.TaskState
+import org.apache.spark.scheduler._
 import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
 
 /**
- * Low-level task scheduler interface, implemented by both ClusterScheduler and LocalScheduler.
- * Each TaskScheduler schedulers task for a single SparkContext.
- * These schedulers get sets of tasks submitted to them from the DAGScheduler for each stage,
- * and are responsible for sending the tasks to the cluster, running them, retrying if there
- * are failures, and mitigating stragglers. They return events to the DAGScheduler.
+ * Schedules tasks for a single SparkContext. Receives a set of tasks from the DAGScheduler for
+ * each stage, and is responsible for sending tasks to executors, running them, retrying if there
+ * are failures, and mitigating stragglers.  Returns events to the DAGScheduler.
+ * 
+ * Clients should first call initialize() and start(), then submit task sets through the
+ * submitTasks method.
+ *
+ * This class can work with multiple types of clusters by acting through a SchedulerBackend.
+ * It can also work with a local setup by using a LocalBackend and setting isLocal to true.
+ * It handles common logic, like determining a scheduling order across jobs, waking up to launch
+ * speculative tasks, etc.
+ *
+ * THREADING: SchedulerBackends and task-submitting clients can call this class from multiple
+ * threads, so it needs locks in public API methods to maintain its state. In addition, some
+ * SchedulerBackends synchronize on themselves when they want to send events here, and then
+ * acquire a lock on us, so we need to make sure that we don't try to lock the backend while
+ * we are holding a lock on ourselves.
  */
-private[spark] trait TaskScheduler {
+private[spark] class TaskScheduler(val sc: SparkContext, isLocal: Boolean = false) extends Logging {
+  // How often to check for speculative tasks
+  val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong
+
+  // Threshold above which we warn user initial TaskSet may be starved
+  val STARVATION_TIMEOUT = System.getProperty("spark.starvation.timeout", "15000").toLong
+
+  // TaskSetManagers are not thread safe, so any access to one should be synchronized
+  // on this class.
+  val activeTaskSets = new HashMap[String, TaskSetManager]
+
+  val taskIdToTaskSetId = new HashMap[Long, String]
+  val taskIdToExecutorId = new HashMap[Long, String]
+  val taskSetTaskIds = new HashMap[String, HashSet[Long]]
+
+  @volatile private var hasReceivedTask = false
+  @volatile private var hasLaunchedTask = false
+  private val starvationTimer = new Timer(true)
+
+  // Incrementing task IDs
+  val nextTaskId = new AtomicLong(0)
+
+  // Which executor IDs we have executors on
+  val activeExecutorIds = new HashSet[String]
+
+  // The set of executors we have on each host; this is used to compute hostsAlive, which
+  // in turn is used to decide when we can attain data locality on a given host
+  private val executorsByHost = new HashMap[String, HashSet[String]]
+
+  private val executorIdToHost = new HashMap[String, String]
+
+  // Listener object to pass upcalls into
+  var dagScheduler: DAGScheduler = null
+
+  var backend: SchedulerBackend = null
+
+  val mapOutputTracker = SparkEnv.get.mapOutputTracker
+
+  var schedulableBuilder: SchedulableBuilder = null
+  var rootPool: Pool = null
+  // default scheduler is FIFO
+  val schedulingMode: SchedulingMode = SchedulingMode.withName(
+    System.getProperty("spark.scheduler.mode", "FIFO"))
+
+  // This is a var so that we can reset it for testing purposes.
+  private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this)
+
+  def setDAGScheduler(dagScheduler: DAGScheduler) {
+    this.dagScheduler = dagScheduler
+  }
+
+  def initialize(context: SchedulerBackend) {
+    backend = context
+    // temporarily set rootPool name to empty
+    rootPool = new Pool("", schedulingMode, 0, 0)
+    schedulableBuilder = {
+      schedulingMode match {
+        case SchedulingMode.FIFO =>
+          new FIFOSchedulableBuilder(rootPool)
+        case SchedulingMode.FAIR =>
+          new FairSchedulableBuilder(rootPool)
+      }
+    }
+    schedulableBuilder.buildPools()
+  }
+
+  def newTaskId(): Long = nextTaskId.getAndIncrement()
+
+  def start() {
+    backend.start()
+
+    if (!isLocal && System.getProperty("spark.speculation", "false").toBoolean) {
+      new Thread("TaskScheduler speculation check") {
+        setDaemon(true)
+
+        override def run() {
+          logInfo("Starting speculative execution thread")
+          while (true) {
+            try {
+              Thread.sleep(SPECULATION_INTERVAL)
+            } catch {
+              case e: InterruptedException => {}
+            }
+            checkSpeculatableTasks()
+          }
+        }
+      }.start()
+    }
+  }
+
+  def submitTasks(taskSet: TaskSet) {
+    val tasks = taskSet.tasks
+    logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
+    this.synchronized {
+      val manager = new TaskSetManager(this, taskSet)
+      activeTaskSets(taskSet.id) = manager
+      schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
+      taskSetTaskIds(taskSet.id) = new HashSet[Long]()
+
+      if (!isLocal && !hasReceivedTask) {
+        starvationTimer.scheduleAtFixedRate(new TimerTask() {
+          override def run() {
+            if (!hasLaunchedTask) {
+              logWarning("Initial job has not accepted any resources; " +
+                "check your cluster UI to ensure that workers are registered " +
+                "and have sufficient memory")
+            } else {
+              this.cancel()
+            }
+          }
+        }, STARVATION_TIMEOUT, STARVATION_TIMEOUT)
+      }
+      hasReceivedTask = true
+    }
+    backend.reviveOffers()
+  }
 
-  def rootPool: Pool
+  def cancelTasks(stageId: Int): Unit = synchronized {
+    logInfo("Cancelling stage " + stageId)
+    activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
+      // There are two possible cases here:
+      // 1. The task set manager has been created and some tasks have been scheduled.
+      //    In this case, send a kill signal to the executors to kill the task and then abort
+      //    the stage.
+      // 2. The task set manager has been created but no tasks have been scheduled. In this case,
+      //    simply abort the stage.
+      val taskIds = taskSetTaskIds(tsm.taskSet.id)
+      if (taskIds.size > 0) {
+        taskIds.foreach { tid =>
+          val execId = taskIdToExecutorId(tid)
+          backend.killTask(tid, execId)
+        }
+      }
+      tsm.error("Stage %d was cancelled".format(stageId))
+    }
+  }
 
-  def schedulingMode: SchedulingMode
+  def taskSetFinished(manager: TaskSetManager): Unit = synchronized {
+    // Check to see if the given task set has been removed. This is possible in the case of
+    // multiple unrecoverable task failures (e.g. if the entire task set is killed when it has
+    // more than one running task).
+    if (activeTaskSets.contains(manager.taskSet.id)) {
+      activeTaskSets -= manager.taskSet.id
+      manager.parent.removeSchedulable(manager)
+      logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name))
+      taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
+      taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id)
+      taskSetTaskIds.remove(manager.taskSet.id)
+    }
+  }
 
-  def start(): Unit
+  /**
+   * Called by cluster manager to offer resources on slaves. We respond by asking our active task
+   * sets for tasks in order of priority. We fill each node with tasks in a round-robin manner so
+   * that tasks are balanced across the cluster.
+   */
+  def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized {
+    SparkEnv.set(sc.env)
 
-  // Invoked after system has successfully initialized (typically in spark context).
-  // Yarn uses this to bootstrap allocation of resources based on preferred locations, wait for slave registerations, etc.
+    // Mark each slave as alive and remember its hostname
+    for (o <- offers) {
+      executorIdToHost(o.executorId) = o.host
+      if (!executorsByHost.contains(o.host)) {
+        executorsByHost(o.host) = new HashSet[String]()
+        executorGained(o.executorId, o.host)
+      }
+    }
+
+    // Build a list of tasks to assign to each worker
+    val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores))
+    val availableCpus = offers.map(o => o.cores).toArray
+    val sortedTaskSets = rootPool.getSortedTaskSetQueue()
+    for (taskSet <- sortedTaskSets) {
+      logDebug("parentName: %s, name: %s, runningTasks: %s".format(
+        taskSet.parent.name, taskSet.name, taskSet.runningTasks))
+    }
+
+    // Take each TaskSet in our scheduling order, and then offer it each node in increasing order
+    // of locality levels so that it gets a chance to launch local tasks on all of them.
+    var launchedTask = false
+    for (taskSet <- sortedTaskSets; maxLocality <- TaskLocality.values) {
+      do {
+        launchedTask = false
+        for (i <- 0 until offers.size) {
+          val execId = offers(i).executorId
+          val host = offers(i).host
+          for (task <- taskSet.resourceOffer(execId, host, availableCpus(i), maxLocality)) {
+            tasks(i) += task
+            val tid = task.taskId
+            taskIdToTaskSetId(tid) = taskSet.taskSet.id
+            taskSetTaskIds(taskSet.taskSet.id) += tid
+            taskIdToExecutorId(tid) = execId
+            activeExecutorIds += execId
+            executorsByHost(host) += execId
+            availableCpus(i) -= 1
+            launchedTask = true
+          }
+        }
+      } while (launchedTask)
+    }
+
+    if (tasks.size > 0) {
+      hasLaunchedTask = true
+    }
+    return tasks
+  }
+
+  def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+    var failedExecutor: Option[String] = None
+    var taskFailed = false
+    synchronized {
+      try {
+        if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) {
+          // We lost this entire executor, so remember that it's gone
+          val execId = taskIdToExecutorId(tid)
+          if (activeExecutorIds.contains(execId)) {
+            removeExecutor(execId)
+            failedExecutor = Some(execId)
+          }
+        }
+        taskIdToTaskSetId.get(tid) match {
+          case Some(taskSetId) =>
+            if (TaskState.isFinished(state)) {
+              taskIdToTaskSetId.remove(tid)
+              if (taskSetTaskIds.contains(taskSetId)) {
+                taskSetTaskIds(taskSetId) -= tid
+              }
+              taskIdToExecutorId.remove(tid)
+            }
+            if (state == TaskState.FAILED) {
+              taskFailed = true
+            }
+            activeTaskSets.get(taskSetId).foreach { taskSet =>
+              if (state == TaskState.FINISHED) {
+                taskSet.removeRunningTask(tid)
+                taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
+              } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
+                taskSet.removeRunningTask(tid)
+                taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
+              }
+            }
+          case None =>
+            logInfo("Ignoring update from TID " + tid + " because its task set is gone")
+        }
+      } catch {
+        case e: Exception => logError("Exception in statusUpdate", e)
+      }
+    }
+    // Update the DAGScheduler without holding a lock on this, since that can deadlock
+    if (failedExecutor != None) {
+      dagScheduler.executorLost(failedExecutor.get)
+      backend.reviveOffers()
+    }
+    if (taskFailed) {
+      // Also revive offers if a task had failed for some reason other than host lost
+      backend.reviveOffers()
+    }
+  }
+
+  def handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long) {
+    taskSetManager.handleTaskGettingResult(tid)
+  }
+
+  def handleSuccessfulTask(
+    taskSetManager: TaskSetManager,
+    tid: Long,
+    taskResult: DirectTaskResult[_]) = synchronized {
+    taskSetManager.handleSuccessfulTask(tid, taskResult)
+  }
+
+  def handleFailedTask(
+    taskSetManager: TaskSetManager,
+    tid: Long,
+    taskState: TaskState,
+    reason: Option[TaskEndReason]) = synchronized {
+    taskSetManager.handleFailedTask(tid, taskState, reason)
+    if (taskState == TaskState.FINISHED) {
+      // The task finished successfully but the result was lost, so we should revive offers.
+      backend.reviveOffers()
+    }
+  }
+
+  def error(message: String) {
+    synchronized {
+      if (activeTaskSets.size > 0) {
+        // Have each task set throw a SparkException with the error
+        for ((taskSetId, manager) <- activeTaskSets) {
+          try {
+            manager.error(message)
+          } catch {
+            case e: Exception => logError("Exception in error callback", e)
+          }
+        }
+      } else {
+        // No task sets are active but we still got an error. Just exit since this
+        // must mean the error is during registration.
+        // It might be good to do something smarter here in the future.
+        logError("Exiting due to error from task scheduler: " + message)
+        System.exit(1)
+      }
+    }
+  }
+
+  def stop() {
+    if (backend != null) {
+      backend.stop()
+    }
+    if (taskResultGetter != null) {
+      taskResultGetter.stop()
+    }
+
+    // sleeping for an arbitrary 5 seconds : to ensure that messages are sent out.
+    // TODO: Do something better !
+    Thread.sleep(5000L)
+  }
+
+  def defaultParallelism() = backend.defaultParallelism()
+
+  // Check for speculatable tasks in all our active jobs.
+  def checkSpeculatableTasks() {
+    var shouldRevive = false
+    synchronized {
+      shouldRevive = rootPool.checkSpeculatableTasks()
+    }
+    if (shouldRevive) {
+      backend.reviveOffers()
+    }
+  }
+
+  // Check for pending tasks in all our active jobs.
+  def hasPendingTasks: Boolean = {
+    synchronized {
+      rootPool.hasPendingTasks()
+    }
+  }
+
+  def executorLost(executorId: String, reason: ExecutorLossReason) {
+    // Set inside the lock only when the executor was still active and has just been removed.
+    var failedExecutor: Option[String] = None
+    synchronized {
+      if (activeExecutorIds.contains(executorId)) {
+        val hostPort = executorIdToHost(executorId)
+        logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason))
+        removeExecutor(executorId)
+        failedExecutor = Some(executorId)
+      } else {
+        // We may get multiple executorLost() calls with different loss reasons. For example, one
+        // may be triggered by a dropped connection from the slave while another may be a report
+        // of executor termination from Mesos. We produce log messages for both so we eventually
+        // report the termination reason.
+        logError("Lost an executor " + executorId + " (already removed): " + reason)
+      }
+    }
+    // Call dagScheduler.executorLost without holding the lock on this to prevent deadlock
+    failedExecutor.foreach { lostExecutorId =>
+      dagScheduler.executorLost(lostExecutorId)
+      backend.reviveOffers()
+    }
+  }
+
+  /** Remove an executor from all our data structures and mark it as lost */
+  private def removeExecutor(executorId: String) {
+    activeExecutorIds -= executorId
+    val host = executorIdToHost(executorId)
+    val execs = executorsByHost.getOrElse(host, new HashSet)
+    execs -= executorId
+    if (execs.isEmpty) {
+      executorsByHost -= host
+    }
+    executorIdToHost -= executorId
+    rootPool.executorLost(executorId, host)
+  }
+
+  def executorGained(execId: String, host: String) {
+    dagScheduler.executorGained(execId, host)
+  }
+
+  def getExecutorsAliveOnHost(host: String): Option[Set[String]] = synchronized {
+    executorsByHost.get(host).map(_.toSet)
+  }
+
+  def hasExecutorsAliveOnHost(host: String): Boolean = synchronized {
+    executorsByHost.contains(host)
+  }
+
+  def isExecutorAlive(execId: String): Boolean = synchronized {
+    activeExecutorIds.contains(execId)
+  }
+
+  // By default, rack is unknown
+  def getRackForHost(value: String): Option[String] = None
+
+  /**
+   * Invoked after the system has successfully been initialized. YARN uses this to bootstrap
+   * allocation of resources based on preferred locations, wait for slave registrations, etc.
+   */
   def postStartHook() { }
+}
+
 
-  // Disconnect from the cluster.
-  def stop(): Unit
+object TaskScheduler {
+  /**
+   * Used to balance containers across hosts.
+   *
+   * Accepts a map of hosts to resource offers for that host, and returns a prioritized list of
+   * resource offers representing the order in which the offers should be used.  The resource
+   * offers are ordered such that we'll allocate one container on each host before allocating a
+   * second container on any host, and so on, in order to reduce the damage if a host fails.
+   *
+   * For example, given <h1, [o1, o2, o3]>, <h2, [o4]>, <h3, [o5, o6]>, returns
+   * [o1, o5, o4, o2, o6, o3]
+   */
+  def prioritizeContainers[K, T] (map: HashMap[K, ArrayBuffer[T]]): List[T] = {
+    val _keyList = new ArrayBuffer[K](map.size)
+    _keyList ++= map.keys
 
-  // Submit a sequence of tasks to run.
-  def submitTasks(taskSet: TaskSet): Unit
+    // order keyList based on population of value in map
+    val keyList = _keyList.sortWith(
+      (left, right) => map(left).size > map(right).size
+    )
 
-  // Cancel a stage.
-  def cancelTasks(stageId: Int)
+    val retval = new ArrayBuffer[T](keyList.size * 2)
+    var index = 0
+    var found = true
 
-  // Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called.
-  def setDAGScheduler(dagScheduler: DAGScheduler): Unit
+    while (found) {
+      found = false
+      for (key <- keyList) {
+        val containerList: ArrayBuffer[T] = map.get(key).getOrElse(null)
+        assert(containerList != null)
+        // Get the index'th entry for this host - if present
+        if (index < containerList.size){
+          retval += containerList.apply(index)
+          found = true
+        }
+      }
+      index += 1
+    }
 
-  // Get the default level of parallelism to use in the cluster, as a hint for sizing jobs.
-  def defaultParallelism(): Int
+    retval.toList
+  }
 }

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/5e91495f/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 90f6bce..13271b1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -17,32 +17,690 @@
 
 package org.apache.spark.scheduler
 
-import java.nio.ByteBuffer
+import java.util.Arrays
 
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+import scala.math.max
+import scala.math.min
+
+import org.apache.spark.{ExceptionFailure, FetchFailed, Logging, Resubmitted, SparkEnv,
+  Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState}
 import org.apache.spark.TaskState.TaskState
+import org.apache.spark.scheduler._
+import org.apache.spark.util.{SystemClock, Clock}
+
 
 /**
- * Tracks and schedules the tasks within a single TaskSet. This class keeps track of the status of
- * each task and is responsible for retries on failure and locality. The main interfaces to it
- * are resourceOffer, which asks the TaskSet whether it wants to run a task on one node, and
- * statusUpdate, which tells it that one of its tasks changed state (e.g. finished).
+ * Schedules the tasks within a single TaskSet in the TaskScheduler. This class keeps track of
+ * each task, retries tasks if they fail (up to a limited number of times), and
+ * handles locality-aware scheduling for this TaskSet via delay scheduling. The main interfaces
+ * to it are resourceOffer, which asks the TaskSet whether it wants to run a task on one node,
+ * and handleSuccessfulTask/handleFailedTask, which tell it that a task finished or failed.
  *
- * THREADING: This class is designed to only be called from code with a lock on the TaskScheduler
- * (e.g. its event handlers). It should not be called from other threads.
+ * THREADING: This class is designed to only be called from code with a lock on the
+ * TaskScheduler (e.g. its event handlers). It should not be called from other threads.
  */
-private[spark] trait TaskSetManager extends Schedulable {
-  def schedulableQueue = null
-  
-  def schedulingMode = SchedulingMode.NONE
-  
-  def taskSet: TaskSet
+private[spark] class TaskSetManager(
+    sched: TaskScheduler,
+    val taskSet: TaskSet,
+    clock: Clock = SystemClock)
+  extends Schedulable with Logging
+{
+  // CPUs to request per task
+  val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toInt
+
+  // Maximum times a task is allowed to fail before failing the job
+  val MAX_TASK_FAILURES = System.getProperty("spark.task.maxFailures", "4").toInt
+
+  // Quantile of tasks at which to start speculation
+  val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble
+  val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble
+
+  // Serializer for closures and tasks.
+  val env = SparkEnv.get
+  val ser = env.closureSerializer.newInstance()
+
+  val tasks = taskSet.tasks
+  val numTasks = tasks.length
+  val copiesRunning = new Array[Int](numTasks)
+  val successful = new Array[Boolean](numTasks)
+  val numFailures = new Array[Int](numTasks)
+  val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil)
+  var tasksSuccessful = 0
+
+  var weight = 1
+  var minShare = 0
+  var priority = taskSet.priority
+  var stageId = taskSet.stageId
+  var name = "TaskSet_"+taskSet.stageId.toString
+  var parent: Pool = null
+
+  var runningTasks = 0
+  private val runningTasksSet = new HashSet[Long]
+
+  // Set of pending tasks for each executor. These collections are actually
+  // treated as stacks, in which new tasks are added to the end of the
+  // ArrayBuffer and removed from the end. This makes it faster to detect
+  // tasks that repeatedly fail because whenever a task failed, it is put
+  // back at the head of the stack. They are also only cleaned up lazily;
+  // when a task is launched, it remains in all the pending lists except
+  // the one that it was launched from, but gets removed from them later.
+  private val pendingTasksForExecutor = new HashMap[String, ArrayBuffer[Int]]
+
+  // Set of pending tasks for each host. Similar to pendingTasksForExecutor,
+  // but at host level.
+  private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]]
+
+  // Set of pending tasks for each rack -- similar to the above.
+  private val pendingTasksForRack = new HashMap[String, ArrayBuffer[Int]]
+
+  // Set containing pending tasks with no locality preferences.
+  val pendingTasksWithNoPrefs = new ArrayBuffer[Int]
+
+  // Set containing all pending tasks (also used as a stack, as above).
+  val allPendingTasks = new ArrayBuffer[Int]
+
+  // Tasks that can be speculated. Since these will be a small fraction of total
+  // tasks, we'll just hold them in a HashSet.
+  val speculatableTasks = new HashSet[Int]
+
+  // Task index, start and finish time for each task attempt (indexed by task ID)
+  val taskInfos = new HashMap[Long, TaskInfo]
+
+  // Did the TaskSet fail?
+  var failed = false
+  var causeOfFailure = ""
+
+  // How frequently to reprint duplicate exceptions in full, in milliseconds
+  val EXCEPTION_PRINT_INTERVAL =
+    System.getProperty("spark.logging.exceptionPrintInterval", "10000").toLong
+
+  // Map of recent exceptions (identified by string representation and top stack frame) to
+  // duplicate count (how many times the same exception has appeared) and time the full exception
+  // was printed. This should ideally be an LRU map that can drop old exceptions automatically.
+  val recentExceptions = HashMap[String, (Int, Long)]()
+
+  // Figure out the current map output tracker epoch and set it on all tasks
+  val epoch = sched.mapOutputTracker.getEpoch
+  logDebug("Epoch for " + taskSet + ": " + epoch)
+  for (t <- tasks) {
+    t.epoch = epoch
+  }
+
+  // Add all our tasks to the pending lists. We do this in reverse order
+  // of task index so that tasks with low indices get launched first.
+  for (i <- (0 until numTasks).reverse) {
+    addPendingTask(i)
+  }
+
+  // Figure out which locality levels we have in our TaskSet, so we can do delay scheduling
+  val myLocalityLevels = computeValidLocalityLevels()
+  val localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level
+
+  // Delay scheduling variables: we keep track of our current locality level and the time we
+  // last launched a task at that level, and move up a level when localityWaits[curLevel] expires.
+  // We then move down if we manage to launch a "more local" task.
+  var currentLocalityIndex = 0    // Index of our current locality level in myLocalityLevels
+  var lastLaunchTime = clock.getTime()  // Time we last launched a task at this level
+
+  override def schedulableQueue = null
+
+  override def schedulingMode = SchedulingMode.NONE
+
+  /**
+   * Add a task to all the pending-task lists that it should be on. If readding is set, we are
+   * re-adding the task so only include it in each list if it's not already there.
+   */
+  private def addPendingTask(index: Int, readding: Boolean = false) {
+    // Utility method that adds `index` to a list only if readding=false or it's not already there
+    def addTo(list: ArrayBuffer[Int]) {
+      if (!readding || !list.contains(index)) {
+        list += index
+      }
+    }
+
+    var hadAliveLocations = false
+    for (loc <- tasks(index).preferredLocations) {
+      for (execId <- loc.executorId) {
+        if (sched.isExecutorAlive(execId)) {
+          addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer))
+          hadAliveLocations = true
+        }
+      }
+      if (sched.hasExecutorsAliveOnHost(loc.host)) {
+        addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer))
+        for (rack <- sched.getRackForHost(loc.host)) {
+          addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer))
+        }
+        hadAliveLocations = true
+      }
+    }
+
+    if (!hadAliveLocations) {
+      // Even though the task might've had preferred locations, all of those hosts or executors
+      // are dead; put it in the no-prefs list so we can schedule it elsewhere right away.
+      addTo(pendingTasksWithNoPrefs)
+    }
+
+    if (!readding) {
+      allPendingTasks += index  // No point scanning this whole list to find the old task there
+    }
+  }
+
+  /**
+   * Return the pending tasks list for a given executor ID, or an empty list if
+   * there is no map entry for that executor
+   */
+  private def getPendingTasksForExecutor(executorId: String): ArrayBuffer[Int] = {
+    pendingTasksForExecutor.getOrElse(executorId, ArrayBuffer())
+  }
+
+  /**
+   * Return the pending tasks list for a given host, or an empty list if
+   * there is no map entry for that host
+   */
+  private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = {
+    pendingTasksForHost.getOrElse(host, ArrayBuffer())
+  }
+
+  /**
+   * Return the pending rack-local task list for a given rack, or an empty list if
+   * there is no map entry for that rack
+   */
+  private def getPendingTasksForRack(rack: String): ArrayBuffer[Int] = {
+    pendingTasksForRack.getOrElse(rack, ArrayBuffer())
+  }
+
+  /**
+   * Dequeue a pending task from the given list and return its index.
+   * Return None if the list is empty.
+   * This method also cleans up any tasks in the list that have already
+   * been launched, since we want that to happen lazily.
+   */
+  private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
+    while (!list.isEmpty) {
+      val index = list.last
+      list.trimEnd(1)
+      if (copiesRunning(index) == 0 && !successful(index)) {
+        return Some(index)
+      }
+    }
+    return None
+  }
+
+  /** Check whether any attempt of the given task has been launched on the given host */
+  private def hasAttemptOnHost(taskIndex: Int, host: String): Boolean = {
+    taskAttempts(taskIndex).exists(_.host == host)
+  }
+
+  /**
+   * Return a speculative task for a given executor if any are available. The task should not have
+   * an attempt running on this host, in case the host is slow. In addition, the task should meet
+   * the given locality constraint.
+   */
+  private def findSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value)
+    : Option[(Int, TaskLocality.Value)] =
+  {
+    speculatableTasks.retain(index => !successful(index)) // Remove finished tasks from set
+
+    if (!speculatableTasks.isEmpty) {
+      // Check for process-local or preference-less tasks; note that tasks can be process-local
+      // on multiple nodes when we replicate cached blocks, as in Spark Streaming
+      for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
+        val prefs = tasks(index).preferredLocations
+        val executors = prefs.flatMap(_.executorId)
+        if (prefs.size == 0 || executors.contains(execId)) {
+          speculatableTasks -= index
+          return Some((index, TaskLocality.PROCESS_LOCAL))
+        }
+      }
+
+      // Check for node-local tasks
+      if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) {
+        for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
+          val locations = tasks(index).preferredLocations.map(_.host)
+          if (locations.contains(host)) {
+            speculatableTasks -= index
+            return Some((index, TaskLocality.NODE_LOCAL))
+          }
+        }
+      }
+
+      // Check for rack-local tasks
+      if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
+        for (rack <- sched.getRackForHost(host)) {
+          for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
+            val racks = tasks(index).preferredLocations.map(_.host).map(sched.getRackForHost)
+            if (racks.contains(rack)) {
+              speculatableTasks -= index
+              return Some((index, TaskLocality.RACK_LOCAL))
+            }
+          }
+        }
+      }
 
+      // Check for non-local tasks
+      if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
+        for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
+          speculatableTasks -= index
+          return Some((index, TaskLocality.ANY))
+        }
+      }
+    }
+
+    return None
+  }
+
+  /**
+   * Dequeue a pending task for a given node and return its index and locality level.
+   * Only search for tasks matching the given locality constraint.
+   */
+  private def findTask(execId: String, host: String, locality: TaskLocality.Value)
+    : Option[(Int, TaskLocality.Value)] =
+  {
+    for (index <- findTaskFromList(getPendingTasksForExecutor(execId))) {
+      return Some((index, TaskLocality.PROCESS_LOCAL))
+    }
+
+    if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) {
+      for (index <- findTaskFromList(getPendingTasksForHost(host))) {
+        return Some((index, TaskLocality.NODE_LOCAL))
+      }
+    }
+
+    if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
+      for {
+        rack <- sched.getRackForHost(host)
+        index <- findTaskFromList(getPendingTasksForRack(rack))
+      } {
+        return Some((index, TaskLocality.RACK_LOCAL))
+      }
+    }
+
+    // Look for no-pref tasks after rack-local tasks since they can run anywhere.
+    for (index <- findTaskFromList(pendingTasksWithNoPrefs)) {
+      return Some((index, TaskLocality.PROCESS_LOCAL))
+    }
+
+    if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
+      for (index <- findTaskFromList(allPendingTasks)) {
+        return Some((index, TaskLocality.ANY))
+      }
+    }
+
+    // Finally, if all else has failed, find a speculative task
+    return findSpeculativeTask(execId, host, locality)
+  }
+
+  /**
+   * Respond to an offer of a single executor from the scheduler by finding a task
+   */
   def resourceOffer(
       execId: String,
       host: String,
       availableCpus: Int,
       maxLocality: TaskLocality.TaskLocality)
-    : Option[TaskDescription]
+    : Option[TaskDescription] =
+  {
+    if (tasksSuccessful < numTasks && availableCpus >= CPUS_PER_TASK) {
+      val curTime = clock.getTime()
+
+      var allowedLocality = getAllowedLocalityLevel(curTime)
+      if (allowedLocality > maxLocality) {
+        allowedLocality = maxLocality   // We're not allowed to search for farther-away tasks
+      }
+
+      findTask(execId, host, allowedLocality) match {
+        case Some((index, taskLocality)) => {
+          // Found a task; do some bookkeeping and return a task description
+          val task = tasks(index)
+          val taskId = sched.newTaskId()
+          // Figure out whether this should count as a preferred launch
+          logInfo("Starting task %s:%d as TID %s on executor %s: %s (%s)".format(
+            taskSet.id, index, taskId, execId, host, taskLocality))
+          // Do various bookkeeping
+          copiesRunning(index) += 1
+          val info = new TaskInfo(taskId, index, curTime, execId, host, taskLocality)
+          taskInfos(taskId) = info
+          taskAttempts(index) = info :: taskAttempts(index)
+          // Update our locality level for delay scheduling
+          currentLocalityIndex = getLocalityIndex(taskLocality)
+          lastLaunchTime = curTime
+          // Serialize and return the task
+          val startTime = clock.getTime()
+          // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here
+          // we assume the task can be serialized without exceptions.
+          val serializedTask = Task.serializeWithDependencies(
+            task, sched.sc.addedFiles, sched.sc.addedJars, ser)
+          val timeTaken = clock.getTime() - startTime
+          addRunningTask(taskId)
+          logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
+            taskSet.id, index, serializedTask.limit, timeTaken))
+          val taskName = "task %s:%d".format(taskSet.id, index)
+          if (taskAttempts(index).size == 1)
+            taskStarted(task,info)
+          return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask))
+        }
+        case _ =>
+      }
+    }
+    return None
+  }
+
+  /**
+   * Get the level we can launch tasks according to delay scheduling, based on current wait time.
+   */
+  private def getAllowedLocalityLevel(curTime: Long): TaskLocality.TaskLocality = {
+    while (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex) &&
+        currentLocalityIndex < myLocalityLevels.length - 1)
+    {
+      // Jump to the next locality level, and remove our waiting time for the current one since
+      // we don't want to count it again on the next one
+      lastLaunchTime += localityWaits(currentLocalityIndex)
+      currentLocalityIndex += 1
+    }
+    myLocalityLevels(currentLocalityIndex)
+  }
+
+  /**
+   * Find the index in myLocalityLevels for a given locality. This is also designed to work with
+   * localities that are not in myLocalityLevels (in case we somehow get those) by returning the
+   * next-biggest level we have. Uses the fact that the last value in myLocalityLevels is ANY.
+   */
+  def getLocalityIndex(locality: TaskLocality.TaskLocality): Int = {
+    var index = 0
+    while (locality > myLocalityLevels(index)) {
+      index += 1
+    }
+    index
+  }
+
+  private def taskStarted(task: Task[_], info: TaskInfo) {
+    sched.dagScheduler.taskStarted(task, info)
+  }
+
+  def handleTaskGettingResult(tid: Long) = {
+    val info = taskInfos(tid)
+    info.markGettingResult()
+    sched.dagScheduler.taskGettingResult(tasks(info.index), info)
+  }
+
+  /**
+   * Marks the task as successful and notifies the DAGScheduler that a task has ended.
+   * A success reported for an index that has already succeeded is only logged and ignored.
+   */
+  def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = {
+    val info = taskInfos(tid)
+    val index = info.index
+    info.markSuccessful()
+    removeRunningTask(tid)
+    if (!successful(index)) {
+      logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format(
+        tid, info.duration, info.host, tasksSuccessful + 1, numTasks))  // + 1 counts this task
+      sched.dagScheduler.taskEnded(
+        tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
+      // Mark successful and stop if all the tasks have succeeded.
+      tasksSuccessful += 1
+      successful(index) = true
+      if (tasksSuccessful == numTasks) {
+        sched.taskSetFinished(this)
+      }
+    } else {
+      logInfo("Ignoring task-finished event for TID " + tid + " because task " +
+        index + " has already completed successfully")
+    }
+  }
+
+  /**
+   * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the
+   * DAG Scheduler.
+   */
+  def handleFailedTask(tid: Long, state: TaskState, reason: Option[TaskEndReason]) {
+    val info = taskInfos(tid)
+    if (info.failed) {
+      return
+    }
+    removeRunningTask(tid)
+    val index = info.index
+    info.markFailed()
+    if (!successful(index)) {
+      logWarning("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
+      copiesRunning(index) -= 1
+      // Check if the problem is a map output fetch failure. In that case, this
+      // task will never succeed on any node, so tell the scheduler about it.
+      reason.foreach {
+        case fetchFailed: FetchFailed =>
+          logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress)
+          sched.dagScheduler.taskEnded(tasks(index), fetchFailed, null, null, info, null)
+          successful(index) = true
+          tasksSuccessful += 1
+          sched.taskSetFinished(this)
+          removeAllRunningTasks()
+          return
+
+        case TaskKilled =>
+          logWarning("Task %d was killed.".format(tid))
+          sched.dagScheduler.taskEnded(tasks(index), reason.get, null, null, info, null)
+          return
+
+        case ef: ExceptionFailure =>
+          sched.dagScheduler.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null))
+          val key = ef.description
+          val now = clock.getTime()
+          val (printFull, dupCount) = {
+            if (recentExceptions.contains(key)) {
+              val (dupCount, printTime) = recentExceptions(key)
+              if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
+                recentExceptions(key) = (0, now)
+                (true, 0)
+              } else {
+                recentExceptions(key) = (dupCount + 1, printTime)
+                (false, dupCount + 1)
+              }
+            } else {
+              recentExceptions(key) = (0, now)
+              (true, 0)
+            }
+          }
+          if (printFull) {
+            val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString))
+            logWarning("Loss was due to %s\n%s\n%s".format(
+              ef.className, ef.description, locs.mkString("\n")))
+          } else {
+            logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount))
+          }
+
+        case TaskResultLost =>
+          logWarning("Lost result for TID %s on host %s".format(tid, info.host))
+          sched.dagScheduler.taskEnded(tasks(index), TaskResultLost, null, null, info, null)
+
+        case _ => {}
+      }
+      // On non-fetch failures, re-enqueue the task as pending for a max number of retries
+      addPendingTask(index)
+      if (state != TaskState.KILLED) {
+        numFailures(index) += 1
+        if (numFailures(index) > MAX_TASK_FAILURES) {
+          logError("Task %s:%d failed more than %d times; aborting job".format(
+            taskSet.id, index, MAX_TASK_FAILURES))
+          abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES))
+        }
+      }
+    } else {
+      logInfo("Ignoring task-lost event for TID " + tid +
+        " because task " + index + " is already finished")
+    }
+  }
+
+  def error(message: String) {
+    // Save the error message
+    abort("Error: " + message)
+  }
+
+  def abort(message: String) {
+    failed = true
+    causeOfFailure = message
+    // TODO: Kill running tasks if we were not terminated due to a Mesos error
+    sched.dagScheduler.taskSetFailed(taskSet, message)
+    removeAllRunningTasks()
+    sched.taskSetFinished(this)
+  }
+
+  /** If the given task ID is not in the set of running tasks, adds it.
+   *
+   * Used to keep track of the number of running tasks, for enforcing scheduling policies.
+   */
+  def addRunningTask(tid: Long) {
+    if (runningTasksSet.add(tid) && parent != null) {
+      parent.increaseRunningTasks(1)
+    }
+    runningTasks = runningTasksSet.size
+  }
+
+  /** If the given task ID is in the set of running tasks, removes it. */
+  def removeRunningTask(tid: Long) {
+    if (runningTasksSet.remove(tid) && parent != null) {
+      parent.decreaseRunningTasks(1)
+    }
+    runningTasks = runningTasksSet.size
+  }
+
+  private def removeAllRunningTasks() {
+    val numRunningTasks = runningTasksSet.size
+    runningTasksSet.clear()
+    if (parent != null) {
+      parent.decreaseRunningTasks(numRunningTasks)
+    }
+    runningTasks = 0
+  }
+
+  override def getSchedulableByName(name: String): Schedulable = {
+    return null
+  }
+
+  override def addSchedulable(schedulable: Schedulable) {}
+
+  override def removeSchedulable(schedulable: Schedulable) {}
+
+  override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
+    // This manager has no child schedulables, so the queue holds it exactly once.
+    val sortedTaskSetQueue = ArrayBuffer[TaskSetManager](this)
+    sortedTaskSetQueue
+  }
+
+  /** Called by TaskScheduler when an executor is lost so we can re-enqueue our tasks */
+  override def executorLost(execId: String, host: String) {
+    logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id)
+
+    // Re-enqueue pending tasks for this host based on the status of the cluster -- for example, a
+    // task that used to have locations on only this host might now go to the no-prefs list. Note
+    // that it's okay if we add a task to the same queue twice (if it had multiple preferred
+    // locations), because findTaskFromList will skip already-running tasks.
+    for (index <- getPendingTasksForExecutor(execId)) {
+      addPendingTask(index, readding=true)
+    }
+    for (index <- getPendingTasksForHost(host)) {
+      addPendingTask(index, readding=true)
+    }
+
+    // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage
+    if (tasks(0).isInstanceOf[ShuffleMapTask]) {
+      for ((tid, info) <- taskInfos if info.executorId == execId) {
+        val index = taskInfos(tid).index
+        if (successful(index)) {
+          successful(index) = false
+          copiesRunning(index) -= 1
+          tasksSuccessful -= 1
+          addPendingTask(index)
+          // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
+          // stage finishes when a total of tasks.size tasks finish.
+          sched.dagScheduler.taskEnded(tasks(index), Resubmitted, null, null, info, null)
+        }
+      }
+    }
+    // Also re-enqueue any tasks that were running on the node
+    for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
+      handleFailedTask(tid, TaskState.KILLED, None)
+    }
+  }
+
+  /**
+   * Check for tasks to be speculated and return true if there are any. This is called periodically
+   * by the TaskScheduler.
+   *
+   * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that
+   * we don't scan the whole task set. It might also help to make this sorted by launch time.
+   */
+  override def checkSpeculatableTasks(): Boolean = {
+    // Can't speculate if we only have one task, or if all tasks have finished.
+    if (numTasks == 1 || tasksSuccessful == numTasks) {
+      return false
+    }
+    var foundTasks = false
+    val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
+    logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
+    if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) {
+      val time = clock.getTime()
+      val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
+      Arrays.sort(durations)
+      val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.size - 1))
+      val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100)
+      // TODO: Threshold should also look at standard deviation of task durations and have a lower
+      // bound based on that.
+      logDebug("Task length threshold for speculation: " + threshold)
+      for ((tid, info) <- taskInfos) {
+        val index = info.index
+        if (!successful(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
+          !speculatableTasks.contains(index)) {
+          logInfo(
+            "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format(
+              taskSet.id, index, info.host, threshold))
+          speculatableTasks += index
+          foundTasks = true
+        }
+      }
+    }
+    return foundTasks
+  }
+
+  override def hasPendingTasks(): Boolean = {
+    numTasks > 0 && tasksSuccessful < numTasks
+  }
+
+  private def getLocalityWait(level: TaskLocality.TaskLocality): Long = {
+    val defaultWait = System.getProperty("spark.locality.wait", "3000")
+    level match {
+      case TaskLocality.PROCESS_LOCAL =>
+        System.getProperty("spark.locality.wait.process", defaultWait).toLong
+      case TaskLocality.NODE_LOCAL =>
+        System.getProperty("spark.locality.wait.node", defaultWait).toLong
+      case TaskLocality.RACK_LOCAL =>
+        System.getProperty("spark.locality.wait.rack", defaultWait).toLong
+      case TaskLocality.ANY =>
+        0L
+    }
+  }
 
-  def error(message: String)
+  /**
+   * Compute the locality levels used in this TaskSet. Assumes that all tasks have already been
+   * added to queues using addPendingTask.
+   */
+  private def computeValidLocalityLevels(): Array[TaskLocality.TaskLocality] = {
+    import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY}
+    val levels = new ArrayBuffer[TaskLocality.TaskLocality]
+    if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0) {
+      levels += PROCESS_LOCAL
+    }
+    if (!pendingTasksForHost.isEmpty && getLocalityWait(NODE_LOCAL) != 0) {
+      levels += NODE_LOCAL
+    }
+    if (!pendingTasksForRack.isEmpty && getLocalityWait(RACK_LOCAL) != 0) {
+      levels += RACK_LOCAL
+    }
+    levels += ANY
+    logDebug("Valid locality levels for " + taskSet + ": " + levels.mkString(", "))
+    levels.toArray
+  }
 }

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/5e91495f/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala b/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala
new file mode 100644
index 0000000..ba6bab3
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+/**
+ * Represents free resources available on an executor.
+ */
+private[spark]
+class WorkerOffer(val executorId: String, val host: String, val cores: Int)

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/5e91495f/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
deleted file mode 100644
index 8503395..0000000
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
+++ /dev/null
@@ -1,486 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler.cluster
-
-import java.nio.ByteBuffer
-import java.util.concurrent.atomic.AtomicLong
-import java.util.{TimerTask, Timer}
-
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
-
-import org.apache.spark._
-import org.apache.spark.TaskState.TaskState
-import org.apache.spark.scheduler._
-import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
-
-/**
- * The main TaskScheduler implementation, for running tasks on a cluster. Clients should first call
- * initialize() and start(), then submit task sets through the submitTasks method.
- *
- * This class can work with multiple types of clusters by acting through a SchedulerBackend.
- * It handles common logic, like determining a scheduling order across jobs, waking up to launch
- * speculative tasks, etc.
- *
- * THREADING: SchedulerBackends and task-submitting clients can call this class from multiple
- * threads, so it needs locks in public API methods to maintain its state. In addition, some
- * SchedulerBackends synchronize on themselves when they want to send events here, and then
- * acquire a lock on us, so we need to make sure that we don't try to lock the backend while
- * we are holding a lock on ourselves.
- */
-private[spark] class ClusterScheduler(val sc: SparkContext)
-  extends TaskScheduler
-  with Logging
-{
-  // How often to check for speculative tasks
-  val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong
-
-  // Threshold above which we warn user initial TaskSet may be starved
-  val STARVATION_TIMEOUT = System.getProperty("spark.starvation.timeout", "15000").toLong
-
-  // ClusterTaskSetManagers are not thread safe, so any access to one should be synchronized
-  // on this class.
-  val activeTaskSets = new HashMap[String, ClusterTaskSetManager]
-
-  val taskIdToTaskSetId = new HashMap[Long, String]
-  val taskIdToExecutorId = new HashMap[Long, String]
-  val taskSetTaskIds = new HashMap[String, HashSet[Long]]
-
-  @volatile private var hasReceivedTask = false
-  @volatile private var hasLaunchedTask = false
-  private val starvationTimer = new Timer(true)
-
-  // Incrementing task IDs
-  val nextTaskId = new AtomicLong(0)
-
-  // Which executor IDs we have executors on
-  val activeExecutorIds = new HashSet[String]
-
-  // The set of executors we have on each host; this is used to compute hostsAlive, which
-  // in turn is used to decide when we can attain data locality on a given host
-  private val executorsByHost = new HashMap[String, HashSet[String]]
-
-  private val executorIdToHost = new HashMap[String, String]
-
-  // Listener object to pass upcalls into
-  var dagScheduler: DAGScheduler = null
-
-  var backend: SchedulerBackend = null
-
-  val mapOutputTracker = SparkEnv.get.mapOutputTracker
-
-  var schedulableBuilder: SchedulableBuilder = null
-  var rootPool: Pool = null
-  // default scheduler is FIFO
-  val schedulingMode: SchedulingMode = SchedulingMode.withName(
-    System.getProperty("spark.scheduler.mode", "FIFO"))
-
-  // This is a var so that we can reset it for testing purposes.
-  private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this)
-
-  override def setDAGScheduler(dagScheduler: DAGScheduler) {
-    this.dagScheduler = dagScheduler
-  }
-
-  def initialize(context: SchedulerBackend) {
-    backend = context
-    // temporarily set rootPool name to empty
-    rootPool = new Pool("", schedulingMode, 0, 0)
-    schedulableBuilder = {
-      schedulingMode match {
-        case SchedulingMode.FIFO =>
-          new FIFOSchedulableBuilder(rootPool)
-        case SchedulingMode.FAIR =>
-          new FairSchedulableBuilder(rootPool)
-      }
-    }
-    schedulableBuilder.buildPools()
-  }
-
-  def newTaskId(): Long = nextTaskId.getAndIncrement()
-
-  override def start() {
-    backend.start()
-
-    if (System.getProperty("spark.speculation", "false").toBoolean) {
-      new Thread("ClusterScheduler speculation check") {
-        setDaemon(true)
-
-        override def run() {
-          logInfo("Starting speculative execution thread")
-          while (true) {
-            try {
-              Thread.sleep(SPECULATION_INTERVAL)
-            } catch {
-              case e: InterruptedException => {}
-            }
-            checkSpeculatableTasks()
-          }
-        }
-      }.start()
-    }
-  }
-
-  override def submitTasks(taskSet: TaskSet) {
-    val tasks = taskSet.tasks
-    logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
-    this.synchronized {
-      val manager = new ClusterTaskSetManager(this, taskSet)
-      activeTaskSets(taskSet.id) = manager
-      schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
-      taskSetTaskIds(taskSet.id) = new HashSet[Long]()
-
-      if (!hasReceivedTask) {
-        starvationTimer.scheduleAtFixedRate(new TimerTask() {
-          override def run() {
-            if (!hasLaunchedTask) {
-              logWarning("Initial job has not accepted any resources; " +
-                "check your cluster UI to ensure that workers are registered " +
-                "and have sufficient memory")
-            } else {
-              this.cancel()
-            }
-          }
-        }, STARVATION_TIMEOUT, STARVATION_TIMEOUT)
-      }
-      hasReceivedTask = true
-    }
-    backend.reviveOffers()
-  }
-
-  override def cancelTasks(stageId: Int): Unit = synchronized {
-    logInfo("Cancelling stage " + stageId)
-    activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
-      // There are two possible cases here:
-      // 1. The task set manager has been created and some tasks have been scheduled.
-      //    In this case, send a kill signal to the executors to kill the task and then abort
-      //    the stage.
-      // 2. The task set manager has been created but no tasks have been scheduled. In this case,
-      //    simply abort the stage.
-      val taskIds = taskSetTaskIds(tsm.taskSet.id)
-      if (taskIds.size > 0) {
-        taskIds.foreach { tid =>
-          val execId = taskIdToExecutorId(tid)
-          backend.killTask(tid, execId)
-        }
-      }
-      tsm.error("Stage %d was cancelled".format(stageId))
-    }
-  }
-
-  def taskSetFinished(manager: TaskSetManager): Unit = synchronized {
-    // Check to see if the given task set has been removed. This is possible in the case of
-    // multiple unrecoverable task failures (e.g. if the entire task set is killed when it has
-    // more than one running task).
-    if (activeTaskSets.contains(manager.taskSet.id)) {
-      activeTaskSets -= manager.taskSet.id
-      manager.parent.removeSchedulable(manager)
-      logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name))
-      taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
-      taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id)
-      taskSetTaskIds.remove(manager.taskSet.id)
-    }
-  }
-
-  /**
-   * Called by cluster manager to offer resources on slaves. We respond by asking our active task
-   * sets for tasks in order of priority. We fill each node with tasks in a round-robin manner so
-   * that tasks are balanced across the cluster.
-   */
-  def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized {
-    SparkEnv.set(sc.env)
-
-    // Mark each slave as alive and remember its hostname
-    for (o <- offers) {
-      executorIdToHost(o.executorId) = o.host
-      if (!executorsByHost.contains(o.host)) {
-        executorsByHost(o.host) = new HashSet[String]()
-        executorGained(o.executorId, o.host)
-      }
-    }
-
-    // Build a list of tasks to assign to each worker
-    val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores))
-    val availableCpus = offers.map(o => o.cores).toArray
-    val sortedTaskSets = rootPool.getSortedTaskSetQueue()
-    for (taskSet <- sortedTaskSets) {
-      logDebug("parentName: %s, name: %s, runningTasks: %s".format(
-        taskSet.parent.name, taskSet.name, taskSet.runningTasks))
-    }
-
-    // Take each TaskSet in our scheduling order, and then offer it each node in increasing order
-    // of locality levels so that it gets a chance to launch local tasks on all of them.
-    var launchedTask = false
-    for (taskSet <- sortedTaskSets; maxLocality <- TaskLocality.values) {
-      do {
-        launchedTask = false
-        for (i <- 0 until offers.size) {
-          val execId = offers(i).executorId
-          val host = offers(i).host
-          for (task <- taskSet.resourceOffer(execId, host, availableCpus(i), maxLocality)) {
-            tasks(i) += task
-            val tid = task.taskId
-            taskIdToTaskSetId(tid) = taskSet.taskSet.id
-            taskSetTaskIds(taskSet.taskSet.id) += tid
-            taskIdToExecutorId(tid) = execId
-            activeExecutorIds += execId
-            executorsByHost(host) += execId
-            availableCpus(i) -= 1
-            launchedTask = true
-          }
-        }
-      } while (launchedTask)
-    }
-
-    if (tasks.size > 0) {
-      hasLaunchedTask = true
-    }
-    return tasks
-  }
-
-  def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
-    var failedExecutor: Option[String] = None
-    var taskFailed = false
-    synchronized {
-      try {
-        if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) {
-          // We lost this entire executor, so remember that it's gone
-          val execId = taskIdToExecutorId(tid)
-          if (activeExecutorIds.contains(execId)) {
-            removeExecutor(execId)
-            failedExecutor = Some(execId)
-          }
-        }
-        taskIdToTaskSetId.get(tid) match {
-          case Some(taskSetId) =>
-            if (TaskState.isFinished(state)) {
-              taskIdToTaskSetId.remove(tid)
-              if (taskSetTaskIds.contains(taskSetId)) {
-                taskSetTaskIds(taskSetId) -= tid
-              }
-              taskIdToExecutorId.remove(tid)
-            }
-            if (state == TaskState.FAILED) {
-              taskFailed = true
-            }
-            activeTaskSets.get(taskSetId).foreach { taskSet =>
-              if (state == TaskState.FINISHED) {
-                taskSet.removeRunningTask(tid)
-                taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
-              } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
-                taskSet.removeRunningTask(tid)
-                taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
-              }
-            }
-          case None =>
-            logInfo("Ignoring update from TID " + tid + " because its task set is gone")
-        }
-      } catch {
-        case e: Exception => logError("Exception in statusUpdate", e)
-      }
-    }
-    // Update the DAGScheduler without holding a lock on this, since that can deadlock
-    if (failedExecutor != None) {
-      dagScheduler.executorLost(failedExecutor.get)
-      backend.reviveOffers()
-    }
-    if (taskFailed) {
-      // Also revive offers if a task had failed for some reason other than host lost
-      backend.reviveOffers()
-    }
-  }
-
-  def handleTaskGettingResult(taskSetManager: ClusterTaskSetManager, tid: Long) {
-    taskSetManager.handleTaskGettingResult(tid)
-  }
-
-  def handleSuccessfulTask(
-    taskSetManager: ClusterTaskSetManager,
-    tid: Long,
-    taskResult: DirectTaskResult[_]) = synchronized {
-    taskSetManager.handleSuccessfulTask(tid, taskResult)
-  }
-
-  def handleFailedTask(
-    taskSetManager: ClusterTaskSetManager,
-    tid: Long,
-    taskState: TaskState,
-    reason: Option[TaskEndReason]) = synchronized {
-    taskSetManager.handleFailedTask(tid, taskState, reason)
-    if (taskState == TaskState.FINISHED) {
-      // The task finished successfully but the result was lost, so we should revive offers.
-      backend.reviveOffers()
-    }
-  }
-
-  def error(message: String) {
-    synchronized {
-      if (activeTaskSets.size > 0) {
-        // Have each task set throw a SparkException with the error
-        for ((taskSetId, manager) <- activeTaskSets) {
-          try {
-            manager.error(message)
-          } catch {
-            case e: Exception => logError("Exception in error callback", e)
-          }
-        }
-      } else {
-        // No task sets are active but we still got an error. Just exit since this
-        // must mean the error is during registration.
-        // It might be good to do something smarter here in the future.
-        logError("Exiting due to error from cluster scheduler: " + message)
-        System.exit(1)
-      }
-    }
-  }
-
-  override def stop() {
-    if (backend != null) {
-      backend.stop()
-    }
-    if (taskResultGetter != null) {
-      taskResultGetter.stop()
-    }
-
-    // sleeping for an arbitrary 5 seconds : to ensure that messages are sent out.
-    // TODO: Do something better !
-    Thread.sleep(5000L)
-  }
-
-  override def defaultParallelism() = backend.defaultParallelism()
-
-
-  // Check for speculatable tasks in all our active jobs.
-  def checkSpeculatableTasks() {
-    var shouldRevive = false
-    synchronized {
-      shouldRevive = rootPool.checkSpeculatableTasks()
-    }
-    if (shouldRevive) {
-      backend.reviveOffers()
-    }
-  }
-
-  // Check for pending tasks in all our active jobs.
-  def hasPendingTasks: Boolean = {
-    synchronized {
-      rootPool.hasPendingTasks()
-    }
-  }
-
-  def executorLost(executorId: String, reason: ExecutorLossReason) {
-    var failedExecutor: Option[String] = None
-
-    synchronized {
-      if (activeExecutorIds.contains(executorId)) {
-        val hostPort = executorIdToHost(executorId)
-        logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason))
-        removeExecutor(executorId)
-        failedExecutor = Some(executorId)
-      } else {
-         // We may get multiple executorLost() calls with different loss reasons. For example, one
-         // may be triggered by a dropped connection from the slave while another may be a report
-         // of executor termination from Mesos. We produce log messages for both so we eventually
-         // report the termination reason.
-         logError("Lost an executor " + executorId + " (already removed): " + reason)
-      }
-    }
-    // Call dagScheduler.executorLost without holding the lock on this to prevent deadlock
-    if (failedExecutor != None) {
-      dagScheduler.executorLost(failedExecutor.get)
-      backend.reviveOffers()
-    }
-  }
-
-  /** Remove an executor from all our data structures and mark it as lost */
-  private def removeExecutor(executorId: String) {
-    activeExecutorIds -= executorId
-    val host = executorIdToHost(executorId)
-    val execs = executorsByHost.getOrElse(host, new HashSet)
-    execs -= executorId
-    if (execs.isEmpty) {
-      executorsByHost -= host
-    }
-    executorIdToHost -= executorId
-    rootPool.executorLost(executorId, host)
-  }
-
-  def executorGained(execId: String, host: String) {
-    dagScheduler.executorGained(execId, host)
-  }
-
-  def getExecutorsAliveOnHost(host: String): Option[Set[String]] = synchronized {
-    executorsByHost.get(host).map(_.toSet)
-  }
-
-  def hasExecutorsAliveOnHost(host: String): Boolean = synchronized {
-    executorsByHost.contains(host)
-  }
-
-  def isExecutorAlive(execId: String): Boolean = synchronized {
-    activeExecutorIds.contains(execId)
-  }
-
-  // By default, rack is unknown
-  def getRackForHost(value: String): Option[String] = None
-}
-
-
-object ClusterScheduler {
-  /**
-   * Used to balance containers across hosts.
-   *
-   * Accepts a map of hosts to resource offers for that host, and returns a prioritized list of
-   * resource offers representing the order in which the offers should be used.  The resource
-   * offers are ordered such that we'll allocate one container on each host before allocating a
-   * second container on any host, and so on, in order to reduce the damage if a host fails.
-   *
-   * For example, given <h1, [o1, o2, o3]>, <h2, [o4]>, <h3, [o5, o6]>, returns
-   * [o1, o5, o4, o2, o6, o3]
-   */
-  def prioritizeContainers[K, T] (map: HashMap[K, ArrayBuffer[T]]): List[T] = {
-    val _keyList = new ArrayBuffer[K](map.size)
-    _keyList ++= map.keys
-
-    // order keyList based on population of value in map
-    val keyList = _keyList.sortWith(
-      (left, right) => map(left).size > map(right).size
-    )
-
-    val retval = new ArrayBuffer[T](keyList.size * 2)
-    var index = 0
-    var found = true
-
-    while (found) {
-      found = false
-      for (key <- keyList) {
-        val containerList: ArrayBuffer[T] = map.get(key).getOrElse(null)
-        assert(containerList != null)
-        // Get the index'th entry for this host - if present
-        if (index < containerList.size){
-          retval += containerList.apply(index)
-          found = true
-        }
-      }
-      index += 1
-    }
-
-    retval.toList
-  }
-}


Mime
View raw message