spark-commits mailing list archives

From yh...@apache.org
Subject spark git commit: [SPARK-18761][CORE] Introduce "task reaper" to oversee task killing in executors
Date Tue, 20 Dec 2016 02:44:05 GMT
Repository: spark
Updated Branches:
  refs/heads/master 5857b9ac2 -> fa829ce21


[SPARK-18761][CORE] Introduce "task reaper" to oversee task killing in executors

## What changes were proposed in this pull request?

Spark's current task cancellation / task killing mechanism is "best effort" because some tasks
may not be interruptible or may not respond to their "killed" flags being set. If a significant
fraction of a cluster's task slots are occupied by tasks that have been marked as killed but
remain running then this can lead to a situation where new jobs and tasks are starved of resources
that are being used by these zombie tasks.

This patch aims to address this problem by adding a "task reaper" mechanism to executors.
At a high level, task killing now launches a new thread which attempts to kill the task and
then monitors it, periodically checking whether it has finished. The TaskReaper calls
`TaskRunner.kill()` only once (a second attempt would be a no-op or may not be safe) and logs a
warning each time it polls and finds the task still running. I modified TaskRunner to rename its
thread at the start of the task, allowing TaskReaper to take a thread dump and filter it in order
to log stacktraces from the exact task thread that we are waiting to finish. If the task has not
stopped after a configurable timeout then the TaskReaper will throw an exception to trigger
executor JVM death, thereby forcibly freeing any resources consumed by the zombie tasks.
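
The monitoring loop described above can be sketched with plain JVM primitives. This is a simplified illustration, not the code in the patch: `ReaperSketch`, `Job`, and `reap` are hypothetical names, the sketch assumes Scala 2.12 or later (for the lambda passed to `Thread`), and it only reports a timeout instead of killing the JVM.

```scala
object ReaperSketch {
  // Hypothetical stand-in for TaskRunner: a finished flag guarded by the object's own monitor.
  final class Job {
    private var finished = false
    def isFinished: Boolean = synchronized { finished }
    def markFinished(): Unit = synchronized {
      finished = true
      notifyAll() // wake every waiting reaper, since there may be more than one
    }
  }

  /** Polls `job` until it finishes or `timeoutMs` elapses; returns true if it finished. */
  def reap(job: Job, pollMs: Long, timeoutMs: Long): Boolean = {
    val start = System.currentTimeMillis()
    def elapsed: Long = System.currentTimeMillis() - start
    // As in the patch, a non-positive timeout means "never give up".
    def timedOut: Boolean = timeoutMs > 0 && elapsed > timeoutMs
    var done = false
    while (!done && !timedOut) {
      job.synchronized {
        // Check and wait under the same lock so a finish between the two cannot be missed.
        if (job.isFinished) done = true else job.wait(pollMs)
      }
      // Re-check outside the lock so that logging never happens while holding it.
      if (job.isFinished) done = true
      else println(s"job is still running after $elapsed ms")
    }
    done
  }

  def main(args: Array[String]): Unit = {
    val quick = new Job
    new Thread(() => { Thread.sleep(50); quick.markFinished() }).start()
    println(reap(quick, pollMs = 20, timeoutMs = 2000))
    println(reap(new Job, pollMs = 20, timeoutMs = 100)) // this job never finishes
  }
}
```

Under normal scheduling the first call should return `true` once the helper thread marks the job finished, and the second should return `false` after the timeout. Checking the flag and calling `wait()` inside the same `synchronized` block is what prevents a lost wakeup.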

This feature is flagged off by default and is controlled by four new configurations under
the `spark.task.reaper.*` namespace. See the updated `configuration.md` doc for details.
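
As a rough illustration of how the four settings fit together, the sketch below builds a `SparkConf` that enables the reaper. The configuration keys come from this patch; the object and method names are hypothetical, and the `120s` kill timeout is an arbitrary example value rather than a recommendation.

```scala
import org.apache.spark.SparkConf

object ReaperConfSketch {
  // Builds a conf with the task reaper switched on. Only the keys are taken from the patch;
  // the values are illustrative.
  def reaperConf(): SparkConf = new SparkConf()
    .set("spark.task.reaper.enabled", "true")        // off by default
    .set("spark.task.reaper.pollingInterval", "10s") // how often killed tasks are polled
    .set("spark.task.reaper.threadDump", "true")     // log the task thread's stack on each poll
    .set("spark.task.reaper.killTimeout", "120s")    // default -1 never kills the executor JVM
}
```

Leaving `spark.task.reaper.killTimeout` at its default of `-1` keeps the monitoring and logging but never terminates the executor.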

## How was this patch tested?

Tested via a new test case in `JobCancellationSuite`, plus manual testing.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #16189 from JoshRosen/cancellation.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fa829ce2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fa829ce2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fa829ce2

Branch: refs/heads/master
Commit: fa829ce21fb84028d90b739a49c4ece70a17ccfd
Parents: 5857b9a
Author: Josh Rosen <joshrosen@databricks.com>
Authored: Mon Dec 19 18:43:59 2016 -0800
Committer: Yin Huai <yhuai@databricks.com>
Committed: Mon Dec 19 18:43:59 2016 -0800

----------------------------------------------------------------------
 .../org/apache/spark/executor/Executor.scala    | 169 ++++++++++++++++++-
 .../scala/org/apache/spark/util/Utils.scala     |  56 +++---
 .../org/apache/spark/JobCancellationSuite.scala |  77 +++++++++
 docs/configuration.md                           |  42 +++++
 4 files changed, 316 insertions(+), 28 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/fa829ce2/core/src/main/scala/org/apache/spark/executor/Executor.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 9501dd9..3346f6d 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -84,6 +84,16 @@ private[spark] class Executor(
   // Start worker thread pool
   private val threadPool = ThreadUtils.newDaemonCachedThreadPool("Executor task launch worker")
   private val executorSource = new ExecutorSource(threadPool, executorId)
+  // Pool used for threads that supervise task killing / cancellation
+  private val taskReaperPool = ThreadUtils.newDaemonCachedThreadPool("Task reaper")
+  // For tasks which are in the process of being killed, this map holds the most recently created
+  // TaskReaper. All accesses to this map should be synchronized on the map itself (this isn't
+  // a ConcurrentHashMap because we use the synchronization for purposes other than simply guarding
+  // the integrity of the map's internal state). The purpose of this map is to prevent the creation
+  // of a separate TaskReaper for every killTask() of a given task. Instead, this map allows us to
+  // track whether an existing TaskReaper fulfills the role of a TaskReaper that we would otherwise
+  // create. The map key is a task id.
+  private val taskReaperForTask: HashMap[Long, TaskReaper] = HashMap[Long, TaskReaper]()
 
   if (!isLocal) {
     env.metricsSystem.registerSource(executorSource)
@@ -93,6 +103,9 @@ private[spark] class Executor(
   // Whether to load classes in user jars before those in Spark jars
   private val userClassPathFirst = conf.getBoolean("spark.executor.userClassPathFirst", false)
 
+  // Whether to monitor killed / interrupted tasks
+  private val taskReaperEnabled = conf.getBoolean("spark.task.reaper.enabled", false)
+
   // Create our ClassLoader
   // do this after SparkEnv creation so can access the SecurityManager
   private val urlClassLoader = createClassLoader()
@@ -148,9 +161,27 @@ private[spark] class Executor(
   }
 
   def killTask(taskId: Long, interruptThread: Boolean): Unit = {
-    val tr = runningTasks.get(taskId)
-    if (tr != null) {
-      tr.kill(interruptThread)
+    val taskRunner = runningTasks.get(taskId)
+    if (taskRunner != null) {
+      if (taskReaperEnabled) {
+        val maybeNewTaskReaper: Option[TaskReaper] = taskReaperForTask.synchronized {
+          val shouldCreateReaper = taskReaperForTask.get(taskId) match {
+            case None => true
+            case Some(existingReaper) => interruptThread && !existingReaper.interruptThread
+          }
+          if (shouldCreateReaper) {
+            val taskReaper = new TaskReaper(taskRunner, interruptThread = interruptThread)
+            taskReaperForTask(taskId) = taskReaper
+            Some(taskReaper)
+          } else {
+            None
+          }
+        }
+        // Execute the TaskReaper from outside of the synchronized block.
+        maybeNewTaskReaper.foreach(taskReaperPool.execute)
+      } else {
+        taskRunner.kill(interruptThread = interruptThread)
+      }
     }
   }
 
@@ -161,12 +192,7 @@ private[spark] class Executor(
    * @param interruptThread whether to interrupt the task thread
    */
   def killAllTasks(interruptThread: Boolean) : Unit = {
-    // kill all the running tasks
-    for (taskRunner <- runningTasks.values().asScala) {
-      if (taskRunner != null) {
-        taskRunner.kill(interruptThread)
-      }
-    }
+    runningTasks.keys().asScala.foreach(t => killTask(t, interruptThread = interruptThread))
   }
 
   def stop(): Unit = {
@@ -192,13 +218,21 @@ private[spark] class Executor(
       serializedTask: ByteBuffer)
     extends Runnable {
 
+    val threadName = s"Executor task launch worker for task $taskId"
+
     /** Whether this task has been killed. */
     @volatile private var killed = false
 
+    @volatile private var threadId: Long = -1
+
+    def getThreadId: Long = threadId
+
     /** Whether this task has been finished. */
     @GuardedBy("TaskRunner.this")
     private var finished = false
 
+    def isFinished: Boolean = synchronized { finished }
+
     /** How much the JVM process has spent in GC when the task starts to run. */
     @volatile var startGCTime: Long = _
 
@@ -229,9 +263,15 @@ private[spark] class Executor(
       // ClosedByInterruptException during execBackend.statusUpdate which causes
       // Executor to crash
       Thread.interrupted()
+      // Notify any waiting TaskReapers. Generally there will only be one reaper per task but there
+      // is a rare corner-case where one task can have two reapers in case cancel(interrupt=False)
+      // is followed by cancel(interrupt=True). Thus we use notifyAll() to avoid a lost wakeup:
+      notifyAll()
     }
 
     override def run(): Unit = {
+      threadId = Thread.currentThread.getId
+      Thread.currentThread.setName(threadName)
       val threadMXBean = ManagementFactory.getThreadMXBean
       val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId)
       val deserializeStartTime = System.currentTimeMillis()
@@ -432,6 +472,117 @@ private[spark] class Executor(
   }
 
   /**
+   * Supervises the killing / cancellation of a task by sending the interrupted flag, optionally
+   * sending a Thread.interrupt(), and monitoring the task until it finishes.
+   *
+   * Spark's current task cancellation / task killing mechanism is "best effort" because some tasks
+   * may not be interruptible or may not respond to their "killed" flags being set. If a significant
+   * fraction of a cluster's task slots are occupied by tasks that have been marked as killed but
+   * remain running then this can lead to a situation where new jobs and tasks are starved of
+   * resources that are being used by these zombie tasks.
+   *
+   * The TaskReaper was introduced in SPARK-18761 as a mechanism to monitor and clean up zombie
+   * tasks. For backwards-compatibility / backportability this component is disabled by default
+   * and must be explicitly enabled by setting `spark.task.reaper.enabled=true`.
+   *
+   * A TaskReaper is created for a particular task when that task is killed / cancelled. Typically
+   * a task will have only one TaskReaper, but it's possible for a task to have up to two reapers
+   * in case kill is called twice with different values for the `interrupt` parameter.
+   *
+   * Once created, a TaskReaper will run until its supervised task has finished running. If the
+   * TaskReaper has not been configured to kill the JVM after a timeout (i.e. if
+   * `spark.task.reaper.killTimeout < 0`) then this implies that the TaskReaper may run indefinitely
+   * if the supervised task never exits.
+   */
+  private class TaskReaper(
+      taskRunner: TaskRunner,
+      val interruptThread: Boolean)
+    extends Runnable {
+
+    private[this] val taskId: Long = taskRunner.taskId
+
+    private[this] val killPollingIntervalMs: Long =
+      conf.getTimeAsMs("spark.task.reaper.pollingInterval", "10s")
+
+    private[this] val killTimeoutMs: Long = conf.getTimeAsMs("spark.task.reaper.killTimeout", "-1")
+
+    private[this] val takeThreadDump: Boolean =
+      conf.getBoolean("spark.task.reaper.threadDump", true)
+
+    override def run(): Unit = {
+      val startTimeMs = System.currentTimeMillis()
+      def elapsedTimeMs = System.currentTimeMillis() - startTimeMs
+      def timeoutExceeded(): Boolean = killTimeoutMs > 0 && elapsedTimeMs > killTimeoutMs
+      try {
+        // Only attempt to kill the task once. If interruptThread = false then a second kill
+        // attempt would be a no-op and if interruptThread = true then it may not be safe or
+        // effective to interrupt multiple times:
+        taskRunner.kill(interruptThread = interruptThread)
+        // Monitor the killed task until it exits. The synchronization logic here is complicated
+        // because we don't want to synchronize on the taskRunner while possibly taking a thread
+        // dump, but we also need to be careful to avoid races between checking whether the task
+        // has finished and wait()ing for it to finish.
+        var finished: Boolean = false
+        while (!finished && !timeoutExceeded()) {
+          taskRunner.synchronized {
+            // We need to synchronize on the TaskRunner while checking whether the task has
+            // finished in order to avoid a race where the task is marked as finished right after
+            // we check and before we call wait().
+            if (taskRunner.isFinished) {
+              finished = true
+            } else {
+              taskRunner.wait(killPollingIntervalMs)
+            }
+          }
+          if (taskRunner.isFinished) {
+            finished = true
+          } else {
+            logWarning(s"Killed task $taskId is still running after $elapsedTimeMs ms")
+            if (takeThreadDump) {
+              try {
+                Utils.getThreadDumpForThread(taskRunner.getThreadId).foreach { thread =>
+                  if (thread.threadName == taskRunner.threadName) {
+                    logWarning(s"Thread dump from task $taskId:\n${thread.stackTrace}")
+                  }
+                }
+              } catch {
+                case NonFatal(e) =>
+                  logWarning("Exception thrown while obtaining thread dump: ", e)
+              }
+            }
+          }
+        }
+
+        if (!taskRunner.isFinished && timeoutExceeded()) {
+          if (isLocal) {
+            logError(s"Killed task $taskId could not be stopped within $killTimeoutMs ms; " +
+              "not killing JVM because we are running in local mode.")
+          } else {
+            // In non-local-mode, the exception thrown here will bubble up to the uncaught exception
+            // handler and cause the executor JVM to exit.
+            throw new SparkException(
+              s"Killing executor JVM because killed task $taskId could not be stopped within " +
+                s"$killTimeoutMs ms.")
+          }
+        }
+      } finally {
+        // Clean up entries in the taskReaperForTask map.
+        taskReaperForTask.synchronized {
+          taskReaperForTask.get(taskId).foreach { taskReaperInMap =>
+            if (taskReaperInMap eq this) {
+              taskReaperForTask.remove(taskId)
+            } else {
+              // This must have been a TaskReaper where interruptThread == false where a subsequent
+              // killTask() call for the same task had interruptThread == true and overwrote the
+              // map entry.
+            }
+          }
+        }
+      }
+    }
+  }
+
+  /**
    * Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes
    * created by the interpreter to the search path
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/fa829ce2/core/src/main/scala/org/apache/spark/util/Utils.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index c6ad154..078cc3d 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.util
 
 import java.io._
-import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo}
+import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInfo}
 import java.net._
 import java.nio.ByteBuffer
 import java.nio.channels.Channels
@@ -2131,28 +2131,46 @@ private[spark] object Utils extends Logging {
     // We need to filter out null values here because dumpAllThreads() may return null array
     // elements for threads that are dead / don't exist.
     val threadInfos = ManagementFactory.getThreadMXBean.dumpAllThreads(true, true).filter(_ != null)
-    threadInfos.sortBy(_.getThreadId).map { case threadInfo =>
-      val monitors = threadInfo.getLockedMonitors.map(m => m.getLockedStackFrame -> m).toMap
-      val stackTrace = threadInfo.getStackTrace.map { frame =>
-        monitors.get(frame) match {
-          case Some(monitor) =>
-            monitor.getLockedStackFrame.toString + s" => holding ${monitor.lockString}"
-          case None =>
-            frame.toString
-        }
-      }.mkString("\n")
-
-      // use a set to dedup re-entrant locks that are held at multiple places
-      val heldLocks = (threadInfo.getLockedSynchronizers.map(_.lockString)
-          ++ threadInfo.getLockedMonitors.map(_.lockString)
-        ).toSet
+    threadInfos.sortBy(_.getThreadId).map(threadInfoToThreadStackTrace)
+  }
 
-      ThreadStackTrace(threadInfo.getThreadId, threadInfo.getThreadName, threadInfo.getThreadState,
-        stackTrace, if (threadInfo.getLockOwnerId < 0) None else Some(threadInfo.getLockOwnerId),
-        Option(threadInfo.getLockInfo).map(_.lockString).getOrElse(""), heldLocks.toSeq)
+  def getThreadDumpForThread(threadId: Long): Option[ThreadStackTrace] = {
+    if (threadId <= 0) {
+      None
+    } else {
+      // The Int.MaxValue here requests the entire untruncated stack trace of the thread:
+      val threadInfo =
+        Option(ManagementFactory.getThreadMXBean.getThreadInfo(threadId, Int.MaxValue))
+      threadInfo.map(threadInfoToThreadStackTrace)
     }
   }
 
+  private def threadInfoToThreadStackTrace(threadInfo: ThreadInfo): ThreadStackTrace = {
+    val monitors = threadInfo.getLockedMonitors.map(m => m.getLockedStackFrame -> m).toMap
+    val stackTrace = threadInfo.getStackTrace.map { frame =>
+      monitors.get(frame) match {
+        case Some(monitor) =>
+          monitor.getLockedStackFrame.toString + s" => holding ${monitor.lockString}"
+        case None =>
+          frame.toString
+      }
+    }.mkString("\n")
+
+    // use a set to dedup re-entrant locks that are held at multiple places
+    val heldLocks =
+      (threadInfo.getLockedSynchronizers ++ threadInfo.getLockedMonitors).map(_.lockString).toSet
+
+    ThreadStackTrace(
+      threadId = threadInfo.getThreadId,
+      threadName = threadInfo.getThreadName,
+      threadState = threadInfo.getThreadState,
+      stackTrace = stackTrace,
+      blockedByThreadId =
+        if (threadInfo.getLockOwnerId < 0) None else Some(threadInfo.getLockOwnerId),
+      blockedByLock = Option(threadInfo.getLockInfo).map(_.lockString).getOrElse(""),
+      holdingLocks = heldLocks.toSeq)
+  }
+
   /**
    * Convert all spark properties set in the given SparkConf to a sequence of java options.
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/fa829ce2/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
index a3490fc..99150a1 100644
--- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
@@ -209,6 +209,83 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
     assert(jobB.get() === 100)
   }
 
+  test("task reaper kills JVM if killed tasks keep running for too long") {
+    val conf = new SparkConf()
+      .set("spark.task.reaper.enabled", "true")
+      .set("spark.task.reaper.killTimeout", "5s")
+    sc = new SparkContext("local-cluster[2,1,1024]", "test", conf)
+
+    // Add a listener to release the semaphore once any tasks are launched.
+    val sem = new Semaphore(0)
+    sc.addSparkListener(new SparkListener {
+      override def onTaskStart(taskStart: SparkListenerTaskStart) {
+        sem.release()
+      }
+    })
+
+    // jobA is the one to be cancelled.
+    val jobA = Future {
+      sc.setJobGroup("jobA", "this is a job to be cancelled", interruptOnCancel = true)
+      sc.parallelize(1 to 10000, 2).map { i =>
+        while (true) { }
+      }.count()
+    }
+
+    // Block until both tasks of job A have started and cancel job A.
+    sem.acquire(2)
+    // Small delay to ensure tasks actually start executing the task body
+    Thread.sleep(1000)
+
+    sc.clearJobGroup()
+    val jobB = sc.parallelize(1 to 100, 2).countAsync()
+    sc.cancelJobGroup("jobA")
+    val e = intercept[SparkException] { ThreadUtils.awaitResult(jobA, 15.seconds) }.getCause
+    assert(e.getMessage contains "cancel")
+
+    // Once A is cancelled, job B should finish fairly quickly.
+    assert(ThreadUtils.awaitResult(jobB, 60.seconds) === 100)
+  }
+
+  test("task reaper will not kill JVM if spark.task.reaper.killTimeout == -1") {
+    val conf = new SparkConf()
+      .set("spark.task.reaper.enabled", "true")
+      .set("spark.task.reaper.killTimeout", "-1")
+      .set("spark.task.reaper.pollingInterval", "1s")
+      .set("spark.deploy.maxExecutorRetries", "1")
+    sc = new SparkContext("local-cluster[2,1,1024]", "test", conf)
+
+    // Add a listener to release the semaphore once any tasks are launched.
+    val sem = new Semaphore(0)
+    sc.addSparkListener(new SparkListener {
+      override def onTaskStart(taskStart: SparkListenerTaskStart) {
+        sem.release()
+      }
+    })
+
+    // jobA is the one to be cancelled.
+    val jobA = Future {
+      sc.setJobGroup("jobA", "this is a job to be cancelled", interruptOnCancel = true)
+      sc.parallelize(1 to 2, 2).map { i =>
+        val startTime = System.currentTimeMillis()
+        while (System.currentTimeMillis() < startTime + 10000) { }
+      }.count()
+    }
+
+    // Block until both tasks of job A have started and cancel job A.
+    sem.acquire(2)
+    // Small delay to ensure tasks actually start executing the task body
+    Thread.sleep(1000)
+
+    sc.clearJobGroup()
+    val jobB = sc.parallelize(1 to 100, 2).countAsync()
+    sc.cancelJobGroup("jobA")
+    val e = intercept[SparkException] { ThreadUtils.awaitResult(jobA, 15.seconds) }.getCause
+    assert(e.getMessage contains "cancel")
+
+    // Once A is cancelled, job B should finish fairly quickly.
+    assert(ThreadUtils.awaitResult(jobB, 60.seconds) === 100)
+  }
+
   test("two jobs sharing the same stage") {
     // sem1: make sure cancel is issued after some tasks are launched
     // twoJobsSharingStageSemaphore:

http://git-wip-us.apache.org/repos/asf/spark/blob/fa829ce2/docs/configuration.md
----------------------------------------------------------------------
diff --git a/docs/configuration.md b/docs/configuration.md
index a8b7197..39bfb3a 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -1423,6 +1423,48 @@ Apart from these, the following properties are also available, and may be useful
     Should be greater than or equal to 1. Number of allowed retries = this value - 1.
   </td>
 </tr>
+<tr>
+  <td><code>spark.task.reaper.enabled</code></td>
+  <td>false</td>
+  <td>
+    Enables monitoring of killed / interrupted tasks. When set to true, any task which is killed
+    will be monitored by the executor until that task actually finishes executing. See the other
+    <code>spark.task.reaper.*</code> configurations for details on how to control the exact behavior
+    of this monitoring. When set to false (the default), task killing will use an older code
+    path which lacks such monitoring.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.task.reaper.pollingInterval</code></td>
+  <td>10s</td>
+  <td>
+    When <code>spark.task.reaper.enabled = true</code>, this setting controls the frequency at which
+    executors will poll the status of killed tasks. If a killed task is still running when polled
+    then a warning will be logged and, by default, a thread-dump of the task will be logged
+    (this thread dump can be disabled via the <code>spark.task.reaper.threadDump</code> setting,
+    which is documented below).
+  </td>
+</tr>
+<tr>
+  <td><code>spark.task.reaper.threadDump</code></td>
+  <td>true</td>
+  <td>
+    When <code>spark.task.reaper.enabled = true</code>, this setting controls whether task thread
+    dumps are logged during periodic polling of killed tasks. Set this to false to disable
+    collection of thread dumps.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.task.reaper.killTimeout</code></td>
+  <td>-1</td>
+  <td>
+    When <code>spark.task.reaper.enabled = true</code>, this setting specifies a timeout after
+    which the executor JVM will kill itself if a killed task has not stopped running. The default
+    value, -1, disables this mechanism and prevents the executor from self-destructing. The purpose
+    of this setting is to act as a safety-net to prevent runaway uncancellable tasks from rendering
+    an executor unusable.
+  </td>
+</tr>
 </table>
 
 #### Dynamic Allocation



