spark-commits mailing list archives

From r...@apache.org
Subject git commit: [SPARK-3027] TaskContext: tighten visibility and provide Java friendly callback API
Date Fri, 15 Aug 2014 01:37:07 GMT
Repository: spark
Updated Branches:
  refs/heads/master fa5a08e67 -> 655699f8b


[SPARK-3027] TaskContext: tighten visibility and provide Java friendly callback API

Note that this also passes the TaskContext itself to the TaskCompletionListener. In the future
we can mark the TaskContext with the exception object if an exception occurs during task execution.
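
For reference, a minimal sketch of the new listener API from a caller's point of view (the openStream and readRecords helpers are hypothetical, not part of this change):

    // Scala closure form; the listener now receives the TaskContext itself.
    override def compute(split: Partition, context: TaskContext): Iterator[String] = {
      val in = openStream(split)          // hypothetical resource acquired per partition
      context.addTaskCompletionListener { ctx =>
        in.close()                        // runs on success, failure, or cancellation
      }
      readRecords(in)                     // hypothetical iterator over the stream
    }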

Author: Reynold Xin <rxin@apache.org>

Closes #1938 from rxin/TaskContext and squashes the following commits:

145de43 [Reynold Xin] Added JavaTaskCompletionListenerImpl for Java API friendly guarantee.
f435ea5 [Reynold Xin] Added license header for TaskCompletionListener.
dc4ed27 [Reynold Xin] [SPARK-3027] TaskContext: tighten the visibility and provide Java friendly callback API


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/655699f8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/655699f8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/655699f8

Branch: refs/heads/master
Commit: 655699f8b7156e8216431393436368e80626cdb2
Parents: fa5a08e
Author: Reynold Xin <rxin@apache.org>
Authored: Thu Aug 14 18:37:02 2014 -0700
Committer: Reynold Xin <rxin@apache.org>
Committed: Thu Aug 14 18:37:02 2014 -0700

----------------------------------------------------------------------
 .../apache/spark/InterruptibleIterator.scala    |  2 +-
 .../scala/org/apache/spark/TaskContext.scala    | 63 +++++++++++++++++---
 .../org/apache/spark/api/python/PythonRDD.scala | 12 ++--
 .../org/apache/spark/rdd/CheckpointRDD.scala    |  2 +-
 .../scala/org/apache/spark/rdd/HadoopRDD.scala  |  2 +-
 .../scala/org/apache/spark/rdd/JdbcRDD.scala    |  2 +-
 .../org/apache/spark/rdd/NewHadoopRDD.scala     |  2 +-
 .../apache/spark/scheduler/DAGScheduler.scala   |  2 +-
 .../org/apache/spark/scheduler/ResultTask.scala |  2 +-
 .../apache/spark/scheduler/ShuffleMapTask.scala |  2 +-
 .../scala/org/apache/spark/scheduler/Task.scala |  2 +-
 .../spark/util/TaskCompletionListener.scala     | 33 ++++++++++
 .../util/JavaTaskCompletionListenerImpl.java    | 39 ++++++++++++
 .../spark/scheduler/TaskContextSuite.scala      |  2 +-
 14 files changed, 144 insertions(+), 23 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/655699f8/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala
index f40baa8..5c262bc 100644
--- a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala
+++ b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala
@@ -33,7 +33,7 @@ class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator
     // is allowed. The assumption is that Thread.interrupted does not have a memory fence in read
     // (just a volatile field in C), while context.interrupted is a volatile in the JVM, which
     // introduces an expensive read fence.
-    if (context.interrupted) {
+    if (context.isInterrupted) {
       throw new TaskKilledException
     } else {
       delegate.hasNext

http://git-wip-us.apache.org/repos/asf/spark/blob/655699f8/core/src/main/scala/org/apache/spark/TaskContext.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index 51f40c3..2b99b8a 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -21,10 +21,18 @@ import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.util.TaskCompletionListener
+
 
 /**
  * :: DeveloperApi ::
  * Contextual information about a task which can be read or mutated during execution.
+ *
+ * @param stageId stage id
+ * @param partitionId index of the partition
+ * @param attemptId the number of attempts to execute this task
+ * @param runningLocally whether the task is running locally in the driver JVM
+ * @param taskMetrics performance metrics of the task
  */
 @DeveloperApi
 class TaskContext(
@@ -39,13 +47,45 @@ class TaskContext(
   def splitId = partitionId
 
   // List of callback functions to execute when the task completes.
-  @transient private val onCompleteCallbacks = new ArrayBuffer[() => Unit]
+  @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener]
 
   // Whether the corresponding task has been killed.
-  @volatile var interrupted: Boolean = false
+  @volatile private var interrupted: Boolean = false
+
+  // Whether the task has completed.
+  @volatile private var completed: Boolean = false
+
+  /** Checks whether the task has completed. */
+  def isCompleted: Boolean = completed
 
-  // Whether the task has completed, before the onCompleteCallbacks are executed.
-  @volatile var completed: Boolean = false
+  /** Checks whether the task has been killed. */
+  def isInterrupted: Boolean = interrupted
+
+  // TODO: Also track whether the task has completed successfully or with exception.
+
+  /**
+   * Add a (Java friendly) listener to be executed on task completion.
+   * This will be called in all situations - success, failure, or cancellation.
+   *
+   * An example use is for HadoopRDD to register a callback to close the input stream.
+   */
+  def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
+    onCompleteCallbacks += listener
+    this
+  }
+
+  /**
+   * Add a listener in the form of a Scala closure to be executed on task completion.
+   * This will be called in all situations - success, failure, or cancellation.
+   *
+   * An example use is for HadoopRDD to register a callback to close the input stream.
+   */
+  def addTaskCompletionListener(f: TaskContext => Unit): this.type = {
+    onCompleteCallbacks += new TaskCompletionListener {
+      override def onTaskCompletion(context: TaskContext): Unit = f(context)
+    }
+    this
+  }
 
   /**
    * Add a callback function to be executed on task completion. An example use
@@ -53,13 +93,22 @@ class TaskContext(
    * Will be called in any situation - success, failure, or cancellation.
    * @param f Callback function.
    */
+  @deprecated("use addTaskCompletionListener", "1.1.0")
   def addOnCompleteCallback(f: () => Unit) {
-    onCompleteCallbacks += f
+    onCompleteCallbacks += new TaskCompletionListener {
+      override def onTaskCompletion(context: TaskContext): Unit = f()
+    }
   }
 
-  def executeOnCompleteCallbacks() {
+  /** Marks the task as completed and triggers the listeners. */
+  private[spark] def markTaskCompleted(): Unit = {
     completed = true
     // Process complete callbacks in the reverse order of registration
-    onCompleteCallbacks.reverse.foreach { _() }
+    onCompleteCallbacks.reverse.foreach { _.onTaskCompletion(this) }
+  }
+
+  /** Marks the task for interruption, i.e. cancellation. */
+  private[spark] def markInterrupted(): Unit = {
+    interrupted = true
   }
 }
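
The call-site hunks below all apply the same mechanical migration. As a hedged before/after sketch (using the CheckpointRDD stream below as the example resource):

    // Before (deprecated by this commit):
    context.addOnCompleteCallback(() => deserializeStream.close())

    // After: the closure receives the TaskContext, and addTaskCompletionListener
    // returns this.type, so registrations can be chained if desired.
    context.addTaskCompletionListener(ctx => deserializeStream.close())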

http://git-wip-us.apache.org/repos/asf/spark/blob/655699f8/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 0b5322c..fefe1cb 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -68,7 +68,7 @@ private[spark] class PythonRDD(
     // Start a thread to feed the process input from our parent's iterator
     val writerThread = new WriterThread(env, worker, split, context)
 
-    context.addOnCompleteCallback { () =>
+    context.addTaskCompletionListener { context =>
       writerThread.shutdownOnTaskCompletion()
 
       // Cleanup the worker socket. This will also cause the Python worker to exit.
@@ -137,7 +137,7 @@ private[spark] class PythonRDD(
           }
         } catch {
 
-          case e: Exception if context.interrupted =>
+          case e: Exception if context.isInterrupted =>
             logDebug("Exception thrown after task interruption", e)
             throw new TaskKilledException
 
@@ -176,7 +176,7 @@ private[spark] class PythonRDD(
 
     /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */
     def shutdownOnTaskCompletion() {
-      assert(context.completed)
+      assert(context.isCompleted)
       this.interrupt()
     }
 
@@ -209,7 +209,7 @@ private[spark] class PythonRDD(
         PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
         dataOut.flush()
       } catch {
-        case e: Exception if context.completed || context.interrupted =>
+        case e: Exception if context.isCompleted || context.isInterrupted =>
           logDebug("Exception thrown after task completion (likely due to cleanup)", e)
 
         case e: Exception =>
@@ -235,10 +235,10 @@ private[spark] class PythonRDD(
     override def run() {
       // Kill the worker if it is interrupted, checking until task completion.
      // TODO: This has a race condition if interruption occurs, as completed may still become true.
-      while (!context.interrupted && !context.completed) {
+      while (!context.isInterrupted && !context.isCompleted) {
         Thread.sleep(2000)
       }
-      if (!context.completed) {
+      if (!context.isCompleted) {
         try {
           logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
           env.destroyPythonWorker(pythonExec, envVars.toMap, worker)

http://git-wip-us.apache.org/repos/asf/spark/blob/655699f8/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
index 34c51b8..2093878 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
@@ -141,7 +141,7 @@ private[spark] object CheckpointRDD extends Logging {
     val deserializeStream = serializer.deserializeStream(fileInputStream)
 
     // Register an on-task-completion callback to close the input stream.
-    context.addOnCompleteCallback(() => deserializeStream.close())
+    context.addTaskCompletionListener(context => deserializeStream.close())
 
     deserializeStream.asIterator.asInstanceOf[Iterator[T]]
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/655699f8/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 8d92ea0..c862331 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -197,7 +197,7 @@ class HadoopRDD[K, V](
       reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
 
       // Register an on-task-completion callback to close the input stream.
-      context.addOnCompleteCallback{ () => closeIfNeeded() }
+      context.addTaskCompletionListener{ context => closeIfNeeded() }
       val key: K = reader.createKey()
       val value: V = reader.createValue()
 

http://git-wip-us.apache.org/repos/asf/spark/blob/655699f8/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
index 8947e66..0e38f22 100644
--- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
@@ -68,7 +68,7 @@ class JdbcRDD[T: ClassTag](
   }
 
   override def compute(thePart: Partition, context: TaskContext) = new NextIterator[T] {
-    context.addOnCompleteCallback{ () => closeIfNeeded() }
+    context.addTaskCompletionListener{ context => closeIfNeeded() }
     val part = thePart.asInstanceOf[JdbcPartition]
     val conn = getConnection()
     val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)

http://git-wip-us.apache.org/repos/asf/spark/blob/655699f8/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index 7dfec9a..58f707b 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -129,7 +129,7 @@ class NewHadoopRDD[K, V](
       context.taskMetrics.inputMetrics = Some(inputMetrics)
 
       // Register an on-task-completion callback to close the input stream.
-      context.addOnCompleteCallback(() => close())
+      context.addTaskCompletionListener(context => close())
       var havePair = false
       var finished = false
 

http://git-wip-us.apache.org/repos/asf/spark/blob/655699f8/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 36bbaaa..b86cfbf 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -634,7 +634,7 @@ class DAGScheduler(
         val result = job.func(taskContext, rdd.iterator(split, taskContext))
         job.listener.taskSucceeded(0, result)
       } finally {
-        taskContext.executeOnCompleteCallbacks()
+        taskContext.markTaskCompleted()
       }
     } catch {
       case e: Exception =>

http://git-wip-us.apache.org/repos/asf/spark/blob/655699f8/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index d09fd7a..2ccbd8e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -61,7 +61,7 @@ private[spark] class ResultTask[T, U](
     try {
       func(context, rdd.iterator(partition, context))
     } finally {
-      context.executeOnCompleteCallbacks()
+      context.markTaskCompleted()
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/655699f8/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 11255c0..381eff2 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -74,7 +74,7 @@ private[spark] class ShuffleMapTask(
         }
         throw e
     } finally {
-      context.executeOnCompleteCallbacks()
+      context.markTaskCompleted()
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/655699f8/core/src/main/scala/org/apache/spark/scheduler/Task.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index cbe0bc0..6aa0cca 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -87,7 +87,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
   def kill(interruptThread: Boolean) {
     _killed = true
     if (context != null) {
-      context.interrupted = true
+      context.markInterrupted()
     }
     if (interruptThread && taskThread != null) {
       taskThread.interrupt()

http://git-wip-us.apache.org/repos/asf/spark/blob/655699f8/core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala b/core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala
new file mode 100644
index 0000000..c1b8bf0
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.util.EventListener
+
+import org.apache.spark.TaskContext
+import org.apache.spark.annotation.DeveloperApi
+
+/**
+ * :: DeveloperApi ::
+ *
+ * Listener providing a callback function to invoke when a task's execution completes.
+ */
+@DeveloperApi
+trait TaskCompletionListener extends EventListener {
+  def onTaskCompletion(context: TaskContext)
+}
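
A rough Scala sketch of implementing this trait directly (the listener class and its stream field are hypothetical):

    class CloseStreamListener(in: java.io.InputStream) extends TaskCompletionListener {
      override def onTaskCompletion(context: TaskContext): Unit = in.close()
    }

    // Registered via: context.addTaskCompletionListener(new CloseStreamListener(in))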

http://git-wip-us.apache.org/repos/asf/spark/blob/655699f8/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java
new file mode 100644
index 0000000..af34cdb
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util;
+
+import org.apache.spark.TaskContext;
+
+
+/**
+ * A simple implementation of TaskCompletionListener that makes sure TaskCompletionListener and
+ * TaskContext is Java friendly.
+ */
+public class JavaTaskCompletionListenerImpl implements TaskCompletionListener {
+
+  @Override
+  public void onTaskCompletion(TaskContext context) {
+    context.isCompleted();
+    context.isInterrupted();
+    context.stageId();
+    context.partitionId();
+    context.runningLocally();
+    context.taskMetrics();
+    context.addTaskCompletionListener(this);
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/655699f8/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index 270f7e6..db2ad82 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -32,7 +32,7 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte
     val rdd = new RDD[String](sc, List()) {
       override def getPartitions = Array[Partition](StubPartition(0))
       override def compute(split: Partition, context: TaskContext) = {
-        context.addOnCompleteCallback(() => TaskContextSuite.completed = true)
+        context.addTaskCompletionListener(context => TaskContextSuite.completed = true)
         sys.error("failed")
       }
     }
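
For completeness, a hedged sketch of exercising the listener directly, rather than through a job run as above (this assumes the constructor defaults for runningLocally and taskMetrics, and a test package with access to the private[spark] markTaskCompleted):

    val context = new TaskContext(stageId = 0, partitionId = 0, attemptId = 0)
    var invoked = false
    context.addTaskCompletionListener(ctx => invoked = true)
    context.markTaskCompleted()
    assert(invoked && context.isCompleted)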


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org

