spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From hvanhov...@apache.org
Subject spark git commit: [SPARK-19889][SQL] Make TaskContext callbacks thread safe
Date Wed, 15 Mar 2017 09:46:11 GMT
Repository: spark
Updated Branches:
  refs/heads/master ee36bc1c9 -> 9ff85be3b


[SPARK-19889][SQL] Make TaskContext callbacks thread safe

## What changes were proposed in this pull request?
It is sometimes useful to use multiple threads in a task to parallelize tasks. These threads
might register some completion/failure listeners to clean up when the task completes or fails.
We currently cannot register such a callback and be sure that it will get called, because
the context might be in the process of invoking its callbacks when the callback gets
registered.

This PR improves this by making sure that you cannot add a completion/failure listener from
a different thread when the context is being marked as completed/failed in another thread.
This is done by synchronizing these methods on the task context itself.

Failure listeners were called only once. Completion listeners now follow the same pattern;
this lifts the idempotency requirement for completion listeners and makes it easier to implement
them. In some cases we can (accidentally) add a completion/failure listener after the fact;
these listeners will be called immediately in order to make sure we can safely clean up after
a task.

As a result of this change we could make the `failed` and `completed` flags non-volatile.
The `isCompleted()` method now uses synchronization to ensure that updates are visible across
threads.

## How was this patch tested?
Adding tests to `TaskContextSuite` to test adding listeners to a completed/failed context.

Author: Herman van Hovell <hvanhovell@databricks.com>

Closes #17244 from hvanhovell/SPARK-19889.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9ff85be3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9ff85be3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9ff85be3

Branch: refs/heads/master
Commit: 9ff85be3bd6bf3a782c0e52fa9c2598d79f310bb
Parents: ee36bc1
Author: Herman van Hovell <hvanhovell@databricks.com>
Authored: Wed Mar 15 10:46:05 2017 +0100
Committer: Herman van Hovell <hvanhovell@databricks.com>
Committed: Wed Mar 15 10:46:05 2017 +0100

----------------------------------------------------------------------
 .../scala/org/apache/spark/TaskContext.scala    | 16 ++--
 .../org/apache/spark/TaskContextImpl.scala      | 85 +++++++++++++-------
 .../spark/scheduler/TaskContextSuite.scala      | 26 ++++++
 3 files changed, 93 insertions(+), 34 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9ff85be3/core/src/main/scala/org/apache/spark/TaskContext.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index f0867ec..5acfce1 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -105,7 +105,9 @@ abstract class TaskContext extends Serializable {
 
   /**
    * Adds a (Java friendly) listener to be executed on task completion.
-   * This will be called in all situation - success, failure, or cancellation.
+   * This will be called in all situations - success, failure, or cancellation. Adding a listener
+   * to an already completed task will result in that listener being called immediately.
+   *
    * An example use is for HadoopRDD to register a callback to close the input stream.
    *
    * Exceptions thrown by the listener will result in failure of the task.
@@ -114,7 +116,9 @@ abstract class TaskContext extends Serializable {
 
   /**
    * Adds a listener in the form of a Scala closure to be executed on task completion.
-   * This will be called in all situations - success, failure, or cancellation.
+   * This will be called in all situations - success, failure, or cancellation. Adding a listener
+   * to an already completed task will result in that listener being called immediately.
+   *
    * An example use is for HadoopRDD to register a callback to close the input stream.
    *
    * Exceptions thrown by the listener will result in failure of the task.
@@ -126,14 +130,14 @@ abstract class TaskContext extends Serializable {
   }
 
   /**
-   * Adds a listener to be executed on task failure.
-   * Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times.
+   * Adds a listener to be executed on task failure. Adding a listener to an already failed task
+   * will result in that listener being called immediately.
    */
   def addTaskFailureListener(listener: TaskFailureListener): TaskContext
 
   /**
-   * Adds a listener to be executed on task failure.
-   * Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times.
+   * Adds a listener to be executed on task failure.  Adding a listener to an already failed task
+   * will result in that listener being called immediately.
    */
   def addTaskFailureListener(f: (TaskContext, Throwable) => Unit): TaskContext = {
     addTaskFailureListener(new TaskFailureListener {

http://git-wip-us.apache.org/repos/asf/spark/blob/9ff85be3/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index dc0d128..ea8dcdf 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -18,6 +18,7 @@
 package org.apache.spark
 
 import java.util.Properties
+import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.mutable.ArrayBuffer
 
@@ -29,6 +30,16 @@ import org.apache.spark.metrics.source.Source
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.util._
 
+/**
+ * A [[TaskContext]] implementation.
+ *
+ * A small note on thread safety. The interrupted & fetchFailed fields are volatile, this makes
+ * sure that updates are always visible across threads. The complete & failed flags and their
+ * callbacks are protected by locking on the context instance. For instance, this ensures
+ * that you cannot add a completion listener in one thread while we are completing (and calling
+ * the completion listeners) in another thread. Other state is immutable, however the exposed
+ * [[TaskMetrics]] & [[MetricsSystem]] objects are not thread safe.
+ */
 private[spark] class TaskContextImpl(
     val stageId: Int,
     val partitionId: Int,
@@ -52,62 +63,79 @@ private[spark] class TaskContextImpl(
   @volatile private var interrupted: Boolean = false
 
   // Whether the task has completed.
-  @volatile private var completed: Boolean = false
+  private var completed: Boolean = false
 
   // Whether the task has failed.
-  @volatile private var failed: Boolean = false
+  private var failed: Boolean = false
+
+  // Throwable that caused the task to fail
+  private var failure: Throwable = _
 
   // If there was a fetch failure in the task, we store it here, to make sure user-code doesn't
   // hide the exception.  See SPARK-19276
   @volatile private var _fetchFailedException: Option[FetchFailedException] = None
 
-  override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
-    onCompleteCallbacks += listener
+  @GuardedBy("this")
+  override def addTaskCompletionListener(listener: TaskCompletionListener)
+      : this.type = synchronized {
+    if (completed) {
+      listener.onTaskCompletion(this)
+    } else {
+      onCompleteCallbacks += listener
+    }
     this
   }
 
-  override def addTaskFailureListener(listener: TaskFailureListener): this.type = {
-    onFailureCallbacks += listener
+  @GuardedBy("this")
+  override def addTaskFailureListener(listener: TaskFailureListener)
+      : this.type = synchronized {
+    if (failed) {
+      listener.onTaskFailure(this, failure)
+    } else {
+      onFailureCallbacks += listener
+    }
     this
   }
 
   /** Marks the task as failed and triggers the failure listeners. */
-  private[spark] def markTaskFailed(error: Throwable): Unit = {
-    // failure callbacks should only be called once
+  @GuardedBy("this")
+  private[spark] def markTaskFailed(error: Throwable): Unit = synchronized {
     if (failed) return
     failed = true
-    val errorMsgs = new ArrayBuffer[String](2)
-    // Process failure callbacks in the reverse order of registration
-    onFailureCallbacks.reverse.foreach { listener =>
-      try {
-        listener.onTaskFailure(this, error)
-      } catch {
-        case e: Throwable =>
-          errorMsgs += e.getMessage
-          logError("Error in TaskFailureListener", e)
-      }
-    }
-    if (errorMsgs.nonEmpty) {
-      throw new TaskCompletionListenerException(errorMsgs, Option(error))
+    failure = error
+    invokeListeners(onFailureCallbacks, "TaskFailureListener", Option(error)) {
+      _.onTaskFailure(this, error)
     }
   }
 
   /** Marks the task as completed and triggers the completion listeners. */
-  private[spark] def markTaskCompleted(): Unit = {
+  @GuardedBy("this")
+  private[spark] def markTaskCompleted(): Unit = synchronized {
+    if (completed) return
     completed = true
+    invokeListeners(onCompleteCallbacks, "TaskCompletionListener", None) {
+      _.onTaskCompletion(this)
+    }
+  }
+
+  private def invokeListeners[T](
+      listeners: Seq[T],
+      name: String,
+      error: Option[Throwable])(
+      callback: T => Unit): Unit = {
     val errorMsgs = new ArrayBuffer[String](2)
-    // Process complete callbacks in the reverse order of registration
-    onCompleteCallbacks.reverse.foreach { listener =>
+    // Process callbacks in the reverse order of registration
+    listeners.reverse.foreach { listener =>
       try {
-        listener.onTaskCompletion(this)
+        callback(listener)
       } catch {
         case e: Throwable =>
           errorMsgs += e.getMessage
-          logError("Error in TaskCompletionListener", e)
+          logError(s"Error in $name", e)
       }
     }
     if (errorMsgs.nonEmpty) {
-      throw new TaskCompletionListenerException(errorMsgs)
+      throw new TaskCompletionListenerException(errorMsgs, error)
     }
   }
 
@@ -116,7 +144,8 @@ private[spark] class TaskContextImpl(
     interrupted = true
   }
 
-  override def isCompleted(): Boolean = completed
+  @GuardedBy("this")
+  override def isCompleted(): Boolean = synchronized(completed)
 
   override def isRunningLocally(): Boolean = false
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9ff85be3/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index 7004128..8f576da 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -228,6 +228,32 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     assert(res === Array("testPropValue,testPropValue"))
   }
 
+  test("immediately call a completion listener if the context is completed") {
+    var invocations = 0
+    val context = TaskContext.empty()
+    context.markTaskCompleted()
+    context.addTaskCompletionListener(_ => invocations += 1)
+    assert(invocations == 1)
+    context.markTaskCompleted()
+    assert(invocations == 1)
+  }
+
+  test("immediately call a failure listener if the context has failed") {
+    var invocations = 0
+    var lastError: Throwable = null
+    val error = new RuntimeException
+    val context = TaskContext.empty()
+    context.markTaskFailed(error)
+    context.addTaskFailureListener { (_, e) =>
+      lastError = e
+      invocations += 1
+    }
+    assert(lastError == error)
+    assert(invocations == 1)
+    context.markTaskFailed(error)
+    assert(lastError == error)
+    assert(invocations == 1)
+  }
 }
 
 private object TaskContextSuite {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message