nemo-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jan...@apache.org
Subject [incubator-nemo] branch master updated: [NEMO-122] Manage Task/Stage/Job states in one place (#52)
Date Mon, 25 Jun 2018 03:55:04 GMT
This is an automated email from the ASF dual-hosted git repository.

jangho pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-nemo.git


The following commit(s) were added to refs/heads/master by this push:
     new 3ac2759  [NEMO-122] Manage Task/Stage/Job states in one place (#52)
3ac2759 is described below

commit 3ac27597fd1e248ebbea9d07ad13070ff07502b0
Author: John Yang <johnyangk@gmail.com>
AuthorDate: Mon Jun 25 12:55:00 2018 +0900

    [NEMO-122] Manage Task/Stage/Job states in one place (#52)
    
    JIRA: [NEMO-122: Manage Task/Stage/Job states in one place](https://issues.apache.org/jira/projects/NEMO/issues/NEMO-122)
    
    **Major changes:**
    - State updates to Stage and Job are handled automatically inside the JobStateManager
    - Only the Task state updates are fed into the JobStateManager from the outside.
    
    **Minor changes to note:**
    - Removes redundant data structures in JobStateManager
    
    **Tests for the changes:**
    - Runtime tests have been modified to reflect the changes in the state management
    
    **Other comments:**
    - This PR tries to maintain the existing scheduler design as much as possible, and only changes the state management part
    - More cleanups and reorganizations of the states will come in the follow-up PRs for issues like NEMO-55 (https://issues.apache.org/jira/projects/NEMO/issues/NEMO-50)
    
    resolves [NEMO-122](https://issues.apache.org/jira/projects/NEMO/issues/NEMO-122)
---
 .../java/edu/snu/nemo/client/DriverEndpoint.java   |   6 +-
 .../snu/nemo/runtime/common/state/StageState.java  |  12 +-
 .../snu/nemo/runtime/common/state/TaskState.java   |  38 +--
 .../snu/nemo/runtime/master/JobStateManager.java   | 376 +++++++++------------
 .../edu/snu/nemo/runtime/master/RuntimeMaster.java |   2 +-
 .../master/scheduler/BatchSingleJobScheduler.java  |  86 ++---
 .../nemo/runtime/master/scheduler/Scheduler.java   |  12 +-
 .../runtime/master/scheduler/SchedulerRunner.java  |   3 +-
 .../nemo/runtime/master/JobStateManagerTest.java   |  51 ++-
 .../scheduler/BatchSingleJobSchedulerTest.java     |   7 +-
 .../master/scheduler/FaultToleranceTest.java       |  74 ++--
 .../master/scheduler/SchedulerTestUtil.java        |  24 +-
 .../snu/nemo/tests/client/ClientEndpointTest.java  |  30 +-
 13 files changed, 308 insertions(+), 413 deletions(-)

diff --git a/client/src/main/java/edu/snu/nemo/client/DriverEndpoint.java b/client/src/main/java/edu/snu/nemo/client/DriverEndpoint.java
index 523463f..fc33eea 100644
--- a/client/src/main/java/edu/snu/nemo/client/DriverEndpoint.java
+++ b/client/src/main/java/edu/snu/nemo/client/DriverEndpoint.java
@@ -54,7 +54,7 @@ public final class DriverEndpoint {
    * @return the current state of the running job.
    */
   JobState.State getState() {
-    return (JobState.State) jobStateManager.getJobState().getStateMachine().getCurrentState();
+    return jobStateManager.getJobState();
   }
 
   /**
@@ -67,7 +67,7 @@ public final class DriverEndpoint {
    */
   JobState.State waitUntilFinish(final long timeout,
                                  final TimeUnit unit) {
-    return (JobState.State) jobStateManager.waitUntilFinish(timeout, unit).getStateMachine().getCurrentState();
+    return jobStateManager.waitUntilFinish(timeout, unit);
   }
 
   /**
@@ -76,6 +76,6 @@ public final class DriverEndpoint {
    * @return the final state of this job.
    */
   JobState.State waitUntilFinish() {
-    return (JobState.State) jobStateManager.waitUntilFinish().getStateMachine().getCurrentState();
+    return jobStateManager.waitUntilFinish();
   }
 }
diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/state/StageState.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/state/StageState.java
index 81a0980..6bf3380 100644
--- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/state/StageState.java
+++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/state/StageState.java
@@ -35,18 +35,15 @@ public final class StageState {
     stateMachineBuilder.addState(State.EXECUTING, "The stage is executing.");
     stateMachineBuilder.addState(State.COMPLETE, "All of this stage's tasks have completed.");
     stateMachineBuilder.addState(State.FAILED_RECOVERABLE, "Stage failed, but is recoverable.");
-    stateMachineBuilder.addState(State.FAILED_UNRECOVERABLE, "Stage failed, and is unrecoverable. The job will fail.");
 
     // Add transitions
     stateMachineBuilder.addTransition(State.READY, State.EXECUTING,
         "The stage can now schedule its tasks");
-    stateMachineBuilder.addTransition(State.READY, State.FAILED_UNRECOVERABLE,
-        "Job Failure");
+    stateMachineBuilder.addTransition(State.READY, State.FAILED_RECOVERABLE,
+        "Recoverable failure");
 
     stateMachineBuilder.addTransition(State.EXECUTING, State.COMPLETE,
         "All tasks complete");
-    stateMachineBuilder.addTransition(State.EXECUTING, State.FAILED_UNRECOVERABLE,
-        "Unrecoverable failure in a task");
     stateMachineBuilder.addTransition(State.EXECUTING, State.FAILED_RECOVERABLE,
         "Recoverable failure in a task");
 
@@ -55,8 +52,8 @@ public final class StageState {
 
     stateMachineBuilder.addTransition(State.FAILED_RECOVERABLE, State.READY,
         "Recoverable stage failure");
-    stateMachineBuilder.addTransition(State.FAILED_RECOVERABLE, State.FAILED_UNRECOVERABLE,
-        "");
+    stateMachineBuilder.addTransition(State.FAILED_RECOVERABLE, State.EXECUTING,
+        "Recoverable stage failure");
 
     stateMachineBuilder.setInitialState(State.READY);
 
@@ -75,7 +72,6 @@ public final class StageState {
     EXECUTING,
     COMPLETE,
     FAILED_RECOVERABLE,
-    FAILED_UNRECOVERABLE
   }
 
   @Override
diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/state/TaskState.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/state/TaskState.java
index e201e32..b47696a 100644
--- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/state/TaskState.java
+++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/state/TaskState.java
@@ -35,37 +35,35 @@ public final class TaskState {
     stateMachineBuilder.addState(State.EXECUTING, "The task is executing.");
     stateMachineBuilder.addState(State.COMPLETE, "The task has completed.");
     stateMachineBuilder.addState(State.FAILED_RECOVERABLE, "Task failed, but is recoverable.");
-    stateMachineBuilder.addState(State.FAILED_UNRECOVERABLE,
-        "Task failed, and is unrecoverable. The job will fail.");
+    stateMachineBuilder.addState(State.FAILED_UNRECOVERABLE, "Task failed, and is unrecoverable. The job will fail.");
     stateMachineBuilder.addState(State.ON_HOLD, "The task is paused for dynamic optimization.");
 
-    // Add transitions
-    stateMachineBuilder.addTransition(State.READY, State.EXECUTING,
-        "Scheduling to executor");
+    // From READY
+    stateMachineBuilder.addTransition(State.READY, State.EXECUTING, "Scheduling to executor");
     stateMachineBuilder.addTransition(State.READY, State.FAILED_RECOVERABLE,
         "Stage Failure by a recoverable failure in another task");
-    stateMachineBuilder.addTransition(State.READY, State.FAILED_UNRECOVERABLE,
-        "Stage Failure");
-
-    stateMachineBuilder.addTransition(State.EXECUTING, State.COMPLETE,
-        "All tasks complete");
-    stateMachineBuilder.addTransition(State.EXECUTING, State.FAILED_UNRECOVERABLE,
-        "Unrecoverable failure in a task/Executor failure");
-    stateMachineBuilder.addTransition(State.EXECUTING, State.FAILED_RECOVERABLE,
-        "Recoverable failure in a task/Container failure");
+
+    // From EXECUTING
+    stateMachineBuilder.addTransition(State.EXECUTING, State.COMPLETE, "Task completed normally");
+    stateMachineBuilder.addTransition(State.EXECUTING, State.FAILED_UNRECOVERABLE, "Unrecoverable failure");
+    stateMachineBuilder.addTransition(State.EXECUTING, State.FAILED_RECOVERABLE, "Recoverable failure");
     stateMachineBuilder.addTransition(State.EXECUTING, State.ON_HOLD, "Task paused for dynamic optimization");
-    stateMachineBuilder.addTransition(State.ON_HOLD, State.COMPLETE, "Task completed after dynamic optimization");
 
+    // From ON_HOLD
+    stateMachineBuilder.addTransition(State.ON_HOLD, State.COMPLETE, "Task completed after being on hold");
+    stateMachineBuilder.addTransition(State.ON_HOLD, State.FAILED_UNRECOVERABLE, "Unrecoverable failure");
+    stateMachineBuilder.addTransition(State.ON_HOLD, State.FAILED_RECOVERABLE, "Recoverable failure");
+
+    // From COMPLETE
+    stateMachineBuilder.addTransition(State.COMPLETE, State.EXECUTING, "Completed before, but re-execute");
     stateMachineBuilder.addTransition(State.COMPLETE, State.FAILED_RECOVERABLE,
         "Recoverable failure in a task/Container failure");
 
-    stateMachineBuilder.addTransition(State.FAILED_RECOVERABLE, State.READY,
-        "Recovered from failure and is ready");
-    stateMachineBuilder.addTransition(State.FAILED_RECOVERABLE, State.FAILED_UNRECOVERABLE,
-        "");
 
-    stateMachineBuilder.setInitialState(State.READY);
+    // From FAILED_RECOVERABLE
+    stateMachineBuilder.addTransition(State.FAILED_RECOVERABLE, State.READY, "Recovered from failure and is ready");
 
+    stateMachineBuilder.setInitialState(State.READY);
     return stateMachineBuilder.build();
   }
 
diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/JobStateManager.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/JobStateManager.java
index 0928caf..220e585 100644
--- a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/JobStateManager.java
+++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/JobStateManager.java
@@ -43,22 +43,24 @@ import edu.snu.nemo.runtime.common.state.TaskState;
 import org.apache.reef.annotations.audience.DriverSide;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
-import java.util.stream.Collectors;
+
+import javax.annotation.concurrent.ThreadSafe;
 import java.util.stream.IntStream;
 
 import static edu.snu.nemo.common.dag.DAG.EMPTY_DAG_DIRECTORY;
 
 /**
- * Manages the states related to a job.
- * This class can be used to track a job's execution status to task level in the future.
- * The methods of this class are synchronized.
+ * Maintains three levels of state machines (JobState, StageState, and TaskState) of a physical plan.
+ * The main API this class provides is onTaskStateChanged(), which directly changes a TaskState.
+ * JobState and StageState are updated internally in the class, and can only be read from the outside.
+ *
+ * (CONCURRENCY) The public methods of this class are synchronized.
  */
 @DriverSide
+@ThreadSafe
 public final class JobStateManager {
   private static final Logger LOG = LoggerFactory.getLogger(JobStateManager.class.getName());
-
   private final String jobId;
-
   private final int maxScheduleAttempt;
 
   /**
@@ -75,37 +77,19 @@ public final class JobStateManager {
   private final Map<String, Integer> taskIdToCurrentAttempt;
 
   /**
-   * Keeps track of the number of schedule attempts for each stage.
-   */
-  private final Map<String, Integer> scheduleAttemptIdxByStage;
-
-  /**
    * Represents the job to manage.
    */
   private final PhysicalPlan physicalPlan;
 
   /**
-   * Used to track stage completion status.
-   * All task ids are added to the set when the a stage begins executing.
-   * Each task id is removed upon completion,
-   * therefore indicating the stage's completion when this set becomes empty.
-   */
-  private final Map<String, Set<String>> stageIdToRemainingTaskSet;
-
-  /**
-   * Used to track job completion status.
-   * All stage ids are added to the set when the this job begins executing.
-   * Each stage id is removed upon completion,
-   * therefore indicating the job's completion when this set becomes empty.
-   */
-  private final Set<String> currentJobStageIds;
-
-  /**
    * A lock and condition to check whether the job is finished or not.
    */
   private final Lock finishLock;
   private final Condition jobFinishedCondition;
 
+  /**
+   * For metrics.
+   */
   private final MetricMessageHandler metricMessageHandler;
   private final Map<String, MetricDataBuilder> metricDataBuilderMap;
 
@@ -121,9 +105,6 @@ public final class JobStateManager {
     this.idToStageStates = new HashMap<>();
     this.idToTaskStates = new HashMap<>();
     this.taskIdToCurrentAttempt = new HashMap<>();
-    this.scheduleAttemptIdxByStage = new HashMap<>();
-    this.stageIdToRemainingTaskSet = new HashMap<>();
-    this.currentJobStageIds = new HashSet<>();
     this.finishLock = new ReentrantLock();
     this.jobFinishedCondition = finishLock.newCondition();
     this.metricDataBuilderMap = new HashMap<>();
@@ -139,7 +120,6 @@ public final class JobStateManager {
 
     // Initialize the states for the job down to task-level.
     physicalPlan.getStageDAG().topologicalDo(stage -> {
-      currentJobStageIds.add(stage.getId());
       idToStageStates.put(stage.getId(), new StageState());
       stage.getTaskIds().forEach(taskId -> {
         idToTaskStates.put(taskId, new TaskState());
@@ -179,216 +159,176 @@ public final class JobStateManager {
   }
 
   /**
-   * Updates the state of the job.
-   * @param newState of the job.
+   * Updates the state of a task.
+   * Task state changes can occur both in master and executor.
+   * State changes that occur in master are
+   * initiated in {@link edu.snu.nemo.runtime.master.scheduler.BatchSingleJobScheduler}.
+   * State changes that occur in executors are sent to master as a control message,
+   * and the call to this method is initiated in {@link edu.snu.nemo.runtime.master.scheduler.BatchSingleJobScheduler}
+   * when the message/event is received.
+   *
+   * @param taskId  the ID of the task.
+   * @param newTaskState     the new state of the task.
    */
-  public synchronized void onJobStateChanged(final JobState.State newState) {
+  public synchronized void onTaskStateChanged(final String taskId, final TaskState.State newTaskState) {
+    // Change task state
+    final StateMachine taskState = idToTaskStates.get(taskId).getStateMachine();
+    LOG.debug("Task State Transition: id {}, from {} to {}",
+        new Object[]{taskId, taskState.getCurrentState(), newTaskState});
+
+    taskState.setState(newTaskState);
+
+    // Handle metrics
     final Map<String, Object> metric = new HashMap<>();
+    switch (newTaskState) {
+      case ON_HOLD:
+      case COMPLETE:
+      case FAILED_UNRECOVERABLE:
+      case FAILED_RECOVERABLE:
+        metric.put("ToState", newTaskState);
+        endMeasurement(taskId, metric);
+        break;
+      case EXECUTING:
+        metric.put("FromState", newTaskState);
+        beginMeasurement(taskId, metric);
+        break;
+      case READY:
+        final int currentAttempt = taskIdToCurrentAttempt.get(taskId) + 1;
+        metric.put("ScheduleAttempt", currentAttempt);
+        if (currentAttempt <= maxScheduleAttempt) {
+          taskIdToCurrentAttempt.put(taskId, currentAttempt);
+        } else {
+          throw new SchedulingException(new Throwable("Exceeded max number of scheduling attempts for " + taskId));
+        }
+        break;
+      default:
+        throw new UnknownExecutionStateException(new Throwable("This task state is unknown"));
+    }
 
-    if (newState == JobState.State.EXECUTING) {
-      LOG.debug("Executing Job ID {}...", this.jobId);
-      jobState.getStateMachine().setState(newState);
-      metric.put("FromState", newState);
-      beginMeasurement(jobId, metric);
-    } else if (newState == JobState.State.COMPLETE || newState == JobState.State.FAILED) {
-      LOG.debug("Job ID {} {}!", new Object[]{jobId, newState});
-      // Awake all threads waiting the finish of this job.
-      finishLock.lock();
-      try {
-        jobState.getStateMachine().setState(newState);
-        metric.put("ToState", newState);
-        endMeasurement(jobId, metric);
+    // Change stage state, if needed
+    final String stageId = RuntimeIdGenerator.getStageIdFromTaskId(taskId);
+    final List<String> tasksOfThisStage = physicalPlan.getStageDAG().getVertexById(stageId).getTaskIds();
+    final long numOfCompletedOrOnHoldTasksInThisStage = tasksOfThisStage
+        .stream()
+        .map(this::getTaskState)
+        .filter(state -> state.equals(TaskState.State.COMPLETE) || state.equals(TaskState.State.ON_HOLD))
+        .count();
+    switch (newTaskState) {
+      case READY:
+        onStageStateChanged(stageId, StageState.State.READY);
+        break;
+      case EXECUTING:
+        onStageStateChanged(stageId, StageState.State.EXECUTING);
+        break;
+      case FAILED_RECOVERABLE:
+        onStageStateChanged(stageId, StageState.State.FAILED_RECOVERABLE);
+        break;
+      case COMPLETE:
+      case ON_HOLD:
+        if (numOfCompletedOrOnHoldTasksInThisStage == tasksOfThisStage.size()) {
+          onStageStateChanged(stageId, StageState.State.COMPLETE);
+        }
+        break;
+      case FAILED_UNRECOVERABLE:
+        break;
+      default:
+        throw new UnknownExecutionStateException(new Throwable("This task state is unknown"));
+    }
 
-        jobFinishedCondition.signalAll();
-      } finally {
-        finishLock.unlock();
-      }
-    } else {
-      throw new IllegalStateTransitionException(new Exception("Illegal Job State Transition"));
+    // Log not-yet-completed tasks for us to track progress
+    if (newTaskState.equals(TaskState.State.COMPLETE)) {
+      LOG.info("{}: {} Task(s) to go", stageId, tasksOfThisStage.size() - numOfCompletedOrOnHoldTasksInThisStage);
     }
   }
 
   /**
+   * (PRIVATE METHOD)
    * Updates the state of a stage.
-   * Stage state changes only occur in master.
    * @param stageId of the stage.
-   * @param newState of the stage.
+   * @param newStageState of the stage.
    */
-  public synchronized void onStageStateChanged(final String stageId, final StageState.State newState) {
+  private void onStageStateChanged(final String stageId, final StageState.State newStageState) {
+    if (newStageState.equals(getStageState(stageId))) {
+      // Ignore duplicate state updates
+      return;
+    }
+
+    // Change stage state
     final StateMachine stageStateMachine = idToStageStates.get(stageId).getStateMachine();
     LOG.debug("Stage State Transition: id {} from {} to {}",
-        new Object[]{stageId, stageStateMachine.getCurrentState(), newState});
-    stageStateMachine.setState(newState);
-    final Map<String, Object> metric = new HashMap<>();
-
-    if (newState == StageState.State.EXECUTING) {
-      if (scheduleAttemptIdxByStage.containsKey(stageId)) {
-        final int numAttempts = scheduleAttemptIdxByStage.get(stageId);
-
-        if (numAttempts < maxScheduleAttempt) {
-          scheduleAttemptIdxByStage.put(stageId, numAttempts + 1);
-        } else {
-          throw new SchedulingException(
-              new Throwable("Exceeded max number of scheduling attempts for " + stageId));
-        }
-      } else {
-        scheduleAttemptIdxByStage.put(stageId, 1);
-      }
+        new Object[]{stageId, stageStateMachine.getCurrentState(), newStageState});
+    stageStateMachine.setState(newStageState);
 
-      metric.put("ScheduleAttempt", scheduleAttemptIdxByStage.get(stageId));
-      metric.put("FromState", newState);
+    // Metric handling
+    final Map<String, Object> metric = new HashMap<>();
+    if (newStageState == StageState.State.EXECUTING) {
+      metric.put("FromState", newStageState);
       beginMeasurement(stageId, metric);
-
-      // if there exists a mapping, this state change is from a failed_recoverable stage,
-      // and there may be tasks that do not need to be re-executed.
-      if (!stageIdToRemainingTaskSet.containsKey(stageId)) {
-        for (final Stage stage : physicalPlan.getStageDAG().getVertices()) {
-          if (stage.getId().equals(stageId)) {
-            Set<String> remainingTaskIds = new HashSet<>();
-            remainingTaskIds.addAll(
-                stage.getTaskIds().stream().collect(Collectors.toSet()));
-            stageIdToRemainingTaskSet.put(stageId, remainingTaskIds);
-            break;
-          }
-        }
-      }
-    } else if (newState == StageState.State.COMPLETE) {
-      metric.put("ToState", newState);
+    } else if (newStageState == StageState.State.COMPLETE) {
+      metric.put("ToState", newStageState);
       endMeasurement(stageId, metric);
+    }
 
-      currentJobStageIds.remove(stageId);
-      if (currentJobStageIds.isEmpty()) {
-        onJobStateChanged(JobState.State.COMPLETE);
-      }
-    } else if (newState == StageState.State.FAILED_RECOVERABLE) {
-      metric.put("ToState", newState);
-      endMeasurement(stageId, metric);
-      currentJobStageIds.add(stageId);
-    } else if (newState == StageState.State.FAILED_UNRECOVERABLE) {
-      metric.put("ToState", newState);
-      endMeasurement(stageId, metric);
+    // Change job state if needed
+    final boolean allStagesCompleted = idToStageStates.values().stream().allMatch(state ->
+        state.getStateMachine().getCurrentState().equals(StageState.State.COMPLETE));
+
+    // (1) Job becomes EXECUTING if not already
+    if (newStageState.equals(StageState.State.EXECUTING)
+        && !getJobState().equals(JobState.State.EXECUTING)) {
+      onJobStateChanged(JobState.State.EXECUTING);
+    }
+    // (2) Job becomes COMPLETE
+    if (allStagesCompleted) {
+      onJobStateChanged(JobState.State.COMPLETE);
     }
   }
 
   /**
-   * Updates the state of a task.
-   * Task state changes can occur both in master and executor.
-   * State changes that occur in master are
-   * initiated in {@link edu.snu.nemo.runtime.master.scheduler.BatchSingleJobScheduler}.
-   * State changes that occur in executors are sent to master as a control message,
-   * and the call to this method is initiated in {@link edu.snu.nemo.runtime.master.scheduler.BatchSingleJobScheduler}
-   * when the message/event is received.
-   *
-   * @param taskId  the ID of the task.
-   * @param newState     the new state of the task.
+   * (PRIVATE METHOD)
+   * Updates the state of the job.
+   * @param newState of the job.
    */
-  public synchronized void onTaskStateChanged(final String taskId, final TaskState.State newState) {
-    final StateMachine taskState = idToTaskStates.get(taskId).getStateMachine();
-    final String stageId = RuntimeIdGenerator.getStageIdFromTaskId(taskId);
-
-    LOG.debug("Task State Transition: id {}, from {} to {}",
-        new Object[]{taskId, taskState.getCurrentState(), newState});
-    final Map<String, Object> metric = new HashMap<>();
-
-    switch (newState) {
-    case ON_HOLD:
-    case COMPLETE:
-      taskState.setState(newState);
-      metric.put("ToState", newState);
-      endMeasurement(taskId, metric);
+  private void onJobStateChanged(final JobState.State newState) {
+    if (newState.equals(getJobState())) {
+      // Ignore duplicate state updates
+      return;
+    }
 
-      if (stageIdToRemainingTaskSet.containsKey(stageId)) {
-        final Set<String> remainingTasks = stageIdToRemainingTaskSet.get(stageId);
-        LOG.info("{}: {} Task(s) to go", stageId, remainingTasks.size());
-        remainingTasks.remove(taskId);
+    jobState.getStateMachine().setState(newState);
 
-        if (remainingTasks.isEmpty()) {
-          onStageStateChanged(stageId, StageState.State.COMPLETE);
-        }
-      } else {
-        throw new IllegalStateTransitionException(
-            new Throwable("The stage has not yet been submitted for execution"));
-      }
-      break;
-    case EXECUTING:
-      taskState.setState(newState);
+    final Map<String, Object> metric = new HashMap<>();
+    if (newState == JobState.State.EXECUTING) {
+      LOG.debug("Executing Job ID {}...", this.jobId);
       metric.put("FromState", newState);
-      beginMeasurement(taskId, metric);
-      break;
-    case FAILED_RECOVERABLE:
-      // Multiple calls to set a task's state to failed_recoverable can occur when
-      // a task is made failed_recoverable early by another task's failure detection in the same stage
-      // and the task finds itself failed_recoverable later, propagating the state change event only then.
-      if (taskState.getCurrentState() != TaskState.State.FAILED_RECOVERABLE) {
-        taskState.setState(newState);
-        metric.put("ToState", newState);
-        endMeasurement(taskId, metric);
-
-        // Mark this stage as failed_recoverable as long as it contains at least one failed_recoverable task
-        if (idToStageStates.get(stageId).getStateMachine().getCurrentState() != StageState.State.FAILED_RECOVERABLE) {
-          onStageStateChanged(stageId, StageState.State.FAILED_RECOVERABLE);
-        }
-
-        if (stageIdToRemainingTaskSet.containsKey(stageId)) {
-          stageIdToRemainingTaskSet.get(stageId).add(taskId);
-        } else {
-          throw new IllegalStateTransitionException(
-              new Throwable("The stage has not yet been submitted for execution"));
-        }
+      beginMeasurement(jobId, metric);
+    } else if (newState == JobState.State.COMPLETE || newState == JobState.State.FAILED) {
+      LOG.debug("Job ID {} {}!", new Object[]{jobId, newState});
 
-        // We'll recover and retry this task
-        taskIdToCurrentAttempt.put(taskId, taskIdToCurrentAttempt.get(taskId) + 1);
-      } else {
-        LOG.info("{} state is already FAILED_RECOVERABLE. Skipping this event.",
-            taskId);
-      }
-      break;
-    case READY:
-      taskState.setState(newState);
-      break;
-    case FAILED_UNRECOVERABLE:
-      taskState.setState(newState);
+      // Wake up all threads waiting for this job to finish.
+      finishLock.lock();
       metric.put("ToState", newState);
-      endMeasurement(taskId, metric);
-      break;
-    default:
-      throw new UnknownExecutionStateException(new Throwable("This task state is unknown"));
-    }
-  }
+      endMeasurement(jobId, metric);
 
-  public synchronized boolean checkStageCompletion(final String stageId) {
-    return stageIdToRemainingTaskSet.get(stageId).isEmpty();
-  }
-
-  public synchronized boolean checkJobTermination() {
-    final Enum currentState = jobState.getStateMachine().getCurrentState();
-    return (currentState == JobState.State.COMPLETE || currentState == JobState.State.FAILED);
-  }
-
-  public synchronized int getAttemptCountForStage(final String stageId) {
-    if (scheduleAttemptIdxByStage.containsKey(stageId)) {
-      return scheduleAttemptIdxByStage.get(stageId);
+      try {
+        jobFinishedCondition.signalAll();
+      } finally {
+        finishLock.unlock();
+      }
     } else {
-      throw new IllegalStateException("No mapping for this stage's attemptIdx, an inconsistent state occurred.");
+      throw new IllegalStateTransitionException(new Exception("Illegal Job State Transition"));
     }
   }
 
-  public synchronized int getCurrentAttemptIndexForTask(final String taskId) {
-    if (taskIdToCurrentAttempt.containsKey(taskId)) {
-      return taskIdToCurrentAttempt.get(taskId);
-    } else {
-      throw new IllegalStateException("No mapping for this task's attemptIdx, an inconsistent state occurred.");
-    }
-  }
 
   /**
    * Wait for this job to be finished and return the final state.
    * @return the final state of this job.
    */
-  public JobState waitUntilFinish() {
+  public JobState.State waitUntilFinish() {
     finishLock.lock();
     try {
-      if (!checkJobTermination()) {
+      if (!isJobDone()) {
         jobFinishedCondition.await();
       }
     } catch (final InterruptedException e) {
@@ -407,11 +347,10 @@ public final class JobStateManager {
    * @param unit of the timeout.
    * @return the final state of this job.
    */
-  public JobState waitUntilFinish(final long timeout,
-                                  final TimeUnit unit) {
+  public JobState.State waitUntilFinish(final long timeout, final TimeUnit unit) {
     finishLock.lock();
     try {
-      if (!checkJobTermination()) {
+      if (!isJobDone()) {
         if (!jobFinishedCondition.await(timeout, unit)) {
           LOG.warn("Timeout during waiting the finish of Job ID {}", jobId);
         }
@@ -425,28 +364,31 @@ public final class JobStateManager {
     return getJobState();
   }
 
+  public synchronized boolean isJobDone() {
+    return (getJobState() == JobState.State.COMPLETE || getJobState() == JobState.State.FAILED);
+  }
   public synchronized String getJobId() {
     return jobId;
   }
 
-  public synchronized JobState getJobState() {
-    return jobState;
-  }
-
-  public synchronized StageState getStageState(final String stageId) {
-    return idToStageStates.get(stageId);
+  public synchronized JobState.State getJobState() {
+    return (JobState.State) jobState.getStateMachine().getCurrentState();
   }
 
-  public synchronized Map<String, StageState> getIdToStageStates() {
-    return idToStageStates;
+  public synchronized StageState.State getStageState(final String stageId) {
+    return (StageState.State) idToStageStates.get(stageId).getStateMachine().getCurrentState();
   }
 
-  public synchronized TaskState getTaskState(final String taskId) {
-    return idToTaskStates.get(taskId);
+  public synchronized TaskState.State getTaskState(final String taskId) {
+    return (TaskState.State) idToTaskStates.get(taskId).getStateMachine().getCurrentState();
   }
 
-  public synchronized Map<String, TaskState> getIdToTaskStates() {
-    return idToTaskStates;
+  public synchronized int getTaskAttempt(final String taskId) {
+    if (taskIdToCurrentAttempt.containsKey(taskId)) {
+      return taskIdToCurrentAttempt.get(taskId);
+    } else {
+      throw new IllegalStateException("No mapping for this task's attemptIdx, an inconsistent state occurred.");
+    }
   }
 
   /**
diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/RuntimeMaster.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/RuntimeMaster.java
index a424ecc..a7df05a 100644
--- a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/RuntimeMaster.java
+++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/RuntimeMaster.java
@@ -275,7 +275,7 @@ public final class RuntimeMaster {
         final ControlMessage.TaskStateChangedMsg taskStateChangedMsg
             = message.getTaskStateChangedMsg();
 
-        scheduler.onTaskStateChanged(taskStateChangedMsg.getExecutorId(),
+        scheduler.onTaskStateReportFromExecutor(taskStateChangedMsg.getExecutorId(),
             taskStateChangedMsg.getTaskId(),
             taskStateChangedMsg.getAttemptIdx(),
             convertTaskState(taskStateChangedMsg.getState()),
diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/BatchSingleJobScheduler.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/BatchSingleJobScheduler.java
index d06e1b5..3506063 100644
--- a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/BatchSingleJobScheduler.java
+++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/BatchSingleJobScheduler.java
@@ -43,6 +43,7 @@ import java.util.stream.Collectors;
 import org.slf4j.Logger;
 
 import static edu.snu.nemo.runtime.common.state.TaskState.State.ON_HOLD;
+import static edu.snu.nemo.runtime.common.state.TaskState.State.READY;
 
 /**
  * (WARNING) Only a single dedicated thread should use the public methods of this class.
@@ -139,25 +140,25 @@ public final class BatchSingleJobScheduler implements Scheduler {
    * @param vertexPutOnHold the ID of vertex that is put on hold. It is null otherwise.
    */
   @Override
-  public void onTaskStateChanged(final String executorId,
-                                 final String taskId,
-                                 final int taskAttemptIndex,
-                                 final TaskState.State newState,
-                                 @Nullable final String vertexPutOnHold,
-                                 final TaskState.RecoverableFailureCause failureCause) {
-    final int currentTaskAttemptIndex = jobStateManager.getCurrentAttemptIndexForTask(taskId);
+  public void onTaskStateReportFromExecutor(final String executorId,
+                                            final String taskId,
+                                            final int taskAttemptIndex,
+                                            final TaskState.State newState,
+                                            @Nullable final String vertexPutOnHold,
+                                            final TaskState.RecoverableFailureCause failureCause) {
+    final int currentTaskAttemptIndex = jobStateManager.getTaskAttempt(taskId);
+
     if (taskAttemptIndex == currentTaskAttemptIndex) {
       // Do change state, as this notification is for the current task attempt.
+      jobStateManager.onTaskStateChanged(taskId, newState);
       switch (newState) {
         case COMPLETE:
-          jobStateManager.onTaskStateChanged(taskId, newState);
           onTaskExecutionComplete(executorId, taskId);
           break;
         case FAILED_RECOVERABLE:
-          onTaskExecutionFailedRecoverable(executorId, taskId, newState, failureCause);
+          onTaskExecutionFailedRecoverable(executorId, taskId, failureCause);
           break;
         case ON_HOLD:
-          jobStateManager.onTaskStateChanged(taskId, newState);
           onTaskExecutionOnHold(executorId, taskId, vertexPutOnHold);
           break;
         case FAILED_UNRECOVERABLE:
@@ -201,8 +202,8 @@ public final class BatchSingleJobScheduler implements Scheduler {
     });
 
     tasksToReExecute.forEach(failedTaskId -> {
-      final int attemptIndex = jobStateManager.getCurrentAttemptIndexForTask(failedTaskId);
-      onTaskStateChanged(executorId, failedTaskId, attemptIndex, TaskState.State.FAILED_RECOVERABLE,
+      final int attemptIndex = jobStateManager.getTaskAttempt(failedTaskId);
+      onTaskStateReportFromExecutor(executorId, failedTaskId, attemptIndex, TaskState.State.FAILED_RECOVERABLE,
           null, TaskState.RecoverableFailureCause.CONTAINER_FAILURE);
     });
 
@@ -289,8 +290,7 @@ public final class BatchSingleJobScheduler implements Scheduler {
 
     // We need to reschedule failed_recoverable stages.
     for (final Stage stageToCheck : currentScheduleGroup) {
-      final StageState.State stageState =
-          (StageState.State) jobStateManager.getStageState(stageToCheck.getId()).getStateMachine().getCurrentState();
+      final StageState.State stageState = jobStateManager.getStageState(stageToCheck.getId());
       switch (stageState) {
         case FAILED_RECOVERABLE:
           stagesToSchedule.add(stageToCheck);
@@ -315,10 +315,8 @@ public final class BatchSingleJobScheduler implements Scheduler {
         physicalPlan.getStageDAG().getTopologicalSort().stream().filter(stage -> {
           if (stage.getScheduleGroupIndex() == currentScheduleGroupIndex + 1) {
             final String stageId = stage.getId();
-            return jobStateManager.getStageState(stageId).getStateMachine().getCurrentState()
-                != StageState.State.EXECUTING
-                && jobStateManager.getStageState(stageId).getStateMachine().getCurrentState()
-                != StageState.State.COMPLETE;
+            return jobStateManager.getStageState(stageId) != StageState.State.EXECUTING
+                && jobStateManager.getStageState(stageId) != StageState.State.COMPLETE;
           }
           return false;
         }).collect(Collectors.toList());
@@ -346,15 +344,13 @@ public final class BatchSingleJobScheduler implements Scheduler {
     final List<StageEdge> stageOutgoingEdges =
         physicalPlan.getStageDAG().getOutgoingEdgesOf(stageToSchedule.getId());
 
-    final Enum stageState = jobStateManager.getStageState(stageToSchedule.getId()).getStateMachine().getCurrentState();
+    final StageState.State stageState = jobStateManager.getStageState(stageToSchedule.getId());
 
     final List<String> taskIdsToSchedule = new LinkedList<>();
     for (final String taskId : stageToSchedule.getTaskIds()) {
       // this happens when the belonging stage's other tasks have failed recoverable,
       // but this task's results are safe.
-      final TaskState.State taskState =
-          (TaskState.State)
-              jobStateManager.getTaskState(taskId).getStateMachine().getCurrentState();
+      final TaskState.State taskState = jobStateManager.getTaskState(taskId);
 
       switch (taskState) {
         case COMPLETE:
@@ -372,7 +368,7 @@ public final class BatchSingleJobScheduler implements Scheduler {
           break;
         case FAILED_RECOVERABLE:
           LOG.info("Re-scheduling {} for failure recovery", taskId);
-          jobStateManager.onTaskStateChanged(taskId, TaskState.State.READY);
+          jobStateManager.onTaskStateChanged(taskId, READY);
           taskIdsToSchedule.add(taskId);
           break;
         case ON_HOLD:
@@ -382,13 +378,7 @@ public final class BatchSingleJobScheduler implements Scheduler {
           throw new SchedulingException(new Throwable("Detected a FAILED_UNRECOVERABLE Task"));
       }
     }
-    if (stageState == StageState.State.FAILED_RECOVERABLE) {
-      // The 'failed_recoverable' stage has been selected as the next stage to execute. Change its state back to 'ready'
-      jobStateManager.onStageStateChanged(stageToSchedule.getId(), StageState.State.READY);
-    }
 
-    // attemptIdx is only initialized/updated when we set the stage's state to executing
-    jobStateManager.onStageStateChanged(stageToSchedule.getId(), StageState.State.EXECUTING);
     LOG.info("Scheduling Stage {}", stageToSchedule.getId());
 
     // each readable and source task will be bounded in executor.
@@ -397,7 +387,7 @@ public final class BatchSingleJobScheduler implements Scheduler {
     taskIdsToSchedule.forEach(taskId -> {
       blockManagerMaster.onProducerTaskScheduled(taskId);
       final int taskIdx = RuntimeIdGenerator.getIndexFromTaskId(taskId);
-      final int attemptIdx = jobStateManager.getCurrentAttemptIndexForTask(taskId);
+      final int attemptIdx = jobStateManager.getTaskAttempt(taskId);
 
       LOG.debug("Enqueueing {}", taskId);
       pendingTaskCollection.add(new Task(
@@ -464,9 +454,9 @@ public final class BatchSingleJobScheduler implements Scheduler {
     }
 
     final String stageIdForTaskUponCompletion = RuntimeIdGenerator.getStageIdFromTaskId(taskId);
-    if (jobStateManager.checkStageCompletion(stageIdForTaskUponCompletion)) {
+    if (jobStateManager.getStageState(stageIdForTaskUponCompletion).equals(StageState.State.COMPLETE)) {
       // if the stage this task belongs to is complete,
-      if (!jobStateManager.checkJobTermination()) { // and if the job is not yet complete or failed,
+      if (!jobStateManager.isJobDone()) {
         scheduleNextStage(stageIdForTaskUponCompletion);
       }
     }
@@ -490,7 +480,7 @@ public final class BatchSingleJobScheduler implements Scheduler {
     final String stageIdForTaskUponCompletion = RuntimeIdGenerator.getStageIdFromTaskId(taskId);
 
     final boolean stageComplete =
-        jobStateManager.checkStageCompletion(stageIdForTaskUponCompletion);
+        jobStateManager.getStageState(stageIdForTaskUponCompletion).equals(StageState.State.COMPLETE);
 
     if (stageComplete) {
       // get optimization vertex from the task.
@@ -517,12 +507,10 @@ public final class BatchSingleJobScheduler implements Scheduler {
    * Action for after task execution has failed but it's recoverable.
    * @param executorId    the ID of the executor
    * @param taskId   the ID of the task
-   * @param newState      the state this situation
    * @param failureCause  the cause of failure
    */
   private void onTaskExecutionFailedRecoverable(final String executorId,
                                                 final String taskId,
-                                                final TaskState.State newState,
                                                 final TaskState.RecoverableFailureCause failureCause) {
     LOG.info("{} failed in {} by {}", taskId, executorId, failureCause);
     executorRegistry.updateExecutor(executorId, (executor, state) -> {
@@ -535,38 +523,12 @@ public final class BatchSingleJobScheduler implements Scheduler {
     switch (failureCause) {
       // Previous task must be re-executed, and incomplete tasks of the belonging stage must be rescheduled.
       case INPUT_READ_FAILURE:
-        jobStateManager.onTaskStateChanged(taskId, newState);
-        LOG.info("All tasks of {} will be made failed_recoverable.", stageId);
-        for (final Stage stage : physicalPlan.getStageDAG().getTopologicalSort()) {
-          if (stage.getId().equals(stageId)) {
-            LOG.info("Removing Tasks for {} before they are scheduled to an executor", stage.getId());
-            pendingTaskCollection.removeTasksAndDescendants(stage.getId());
-            stage.getTaskIds().forEach(dstTaskId -> {
-              if (jobStateManager.getTaskState(dstTaskId).getStateMachine().getCurrentState()
-                  != TaskState.State.COMPLETE) {
-                jobStateManager.onTaskStateChanged(dstTaskId, TaskState.State.FAILED_RECOVERABLE);
-                blockManagerMaster.onProducerTaskFailed(dstTaskId);
-              }
-            });
-            break;
-          }
-        }
-        // the stage this task belongs to has become failed recoverable.
-        // it is a good point to start searching for another stage to schedule.
-        scheduleNextStage(stageId);
-        break;
-      // The task executed successfully but there is something wrong with the output store.
+        // TODO #50: Carefully retry tasks in the scheduler
       case OUTPUT_WRITE_FAILURE:
-        jobStateManager.onTaskStateChanged(taskId, newState);
-        LOG.info("Only the failed task will be retried.");
-
-        // the stage this task belongs to has become failed recoverable.
-        // it is a good point to start searching for another stage to schedule.
         blockManagerMaster.onProducerTaskFailed(taskId);
         scheduleNextStage(stageId);
         break;
       case CONTAINER_FAILURE:
-        jobStateManager.onTaskStateChanged(taskId, newState);
         LOG.info("Only the failed task will be retried.");
         break;
       default:
diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/Scheduler.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/Scheduler.java
index c88e35a..ebf7937 100644
--- a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/Scheduler.java
+++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/Scheduler.java
@@ -74,12 +74,12 @@ public interface Scheduler {
    * @param taskPutOnHold the ID of task that are put on hold. It is null otherwise.
    * @param failureCause for which the Task failed in the case of a recoverable failure.
    */
-  void onTaskStateChanged(String executorId,
-                          String taskId,
-                          int attemptIdx,
-                          TaskState.State newState,
-                          @Nullable String taskPutOnHold,
-                          TaskState.RecoverableFailureCause failureCause);
+  void onTaskStateReportFromExecutor(String executorId,
+                                     String taskId,
+                                     int attemptIdx,
+                                     TaskState.State newState,
+                                     @Nullable String taskPutOnHold,
+                                     TaskState.RecoverableFailureCause failureCause);
 
   /**
    * To be called when a job should be terminated.
diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/SchedulerRunner.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/SchedulerRunner.java
index a329c4b..243ef07 100644
--- a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/SchedulerRunner.java
+++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/SchedulerRunner.java
@@ -17,7 +17,6 @@ package edu.snu.nemo.runtime.master.scheduler;
 
 import com.google.common.annotations.VisibleForTesting;
 import edu.snu.nemo.runtime.common.plan.Task;
-import edu.snu.nemo.runtime.common.state.JobState;
 import edu.snu.nemo.runtime.common.state.TaskState;
 import edu.snu.nemo.runtime.master.JobStateManager;
 import edu.snu.nemo.runtime.master.resource.ExecutorRepresenter;
@@ -170,7 +169,7 @@ public final class SchedulerRunner {
         doScheduleStage();
       }
       jobStateManagers.values().forEach(jobStateManager -> {
-        if (jobStateManager.getJobState().getStateMachine().getCurrentState() == JobState.State.COMPLETE) {
+        if (jobStateManager.isJobDone()) {
           LOG.info("{} is complete.", jobStateManager.getJobId());
         } else {
           LOG.info("{} is incomplete.", jobStateManager.getJobId());
diff --git a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/JobStateManagerTest.java b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/JobStateManagerTest.java
index db5b1d4..4aa2b0a 100644
--- a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/JobStateManagerTest.java
+++ b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/JobStateManagerTest.java
@@ -15,8 +15,6 @@
  */
 package edu.snu.nemo.runtime.master;
 
-import edu.snu.nemo.common.ir.edge.IREdge;
-import edu.snu.nemo.common.ir.vertex.IRVertex;
 import edu.snu.nemo.conf.JobConf;
 import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
 import edu.snu.nemo.runtime.common.message.MessageEnvironment;
@@ -25,12 +23,9 @@ import edu.snu.nemo.runtime.common.message.local.LocalMessageEnvironment;
 import edu.snu.nemo.runtime.common.plan.PhysicalPlan;
 import edu.snu.nemo.runtime.common.plan.PhysicalPlanGenerator;
 import edu.snu.nemo.runtime.common.plan.Stage;
-import edu.snu.nemo.runtime.common.plan.StageEdge;
 import edu.snu.nemo.runtime.common.state.JobState;
 import edu.snu.nemo.runtime.common.state.StageState;
 import edu.snu.nemo.runtime.common.state.TaskState;
-import edu.snu.nemo.common.dag.DAG;
-import edu.snu.nemo.common.dag.DAGBuilder;
 import edu.snu.nemo.runtime.plangenerator.TestPlanGenerator;
 import org.apache.reef.tang.Injector;
 import org.apache.reef.tang.Tang;
@@ -41,10 +36,9 @@ import org.powermock.core.classloader.annotations.PrepareForTest;
 import org.powermock.modules.junit4.PowerMockRunner;
 
 import java.util.List;
-import java.util.Map;
 import java.util.concurrent.*;
+import java.util.stream.Collectors;
 
-import static junit.framework.TestCase.assertTrue;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.mockito.Mockito.mock;
@@ -56,14 +50,12 @@ import static org.mockito.Mockito.mock;
 @PrepareForTest(MetricMessageHandler.class)
 public final class JobStateManagerTest {
   private static final int MAX_SCHEDULE_ATTEMPT = 2;
-  private DAGBuilder<IRVertex, IREdge> irDAGBuilder;
   private BlockManagerMaster blockManagerMaster;
   private MetricMessageHandler metricMessageHandler;
   private PhysicalPlanGenerator physicalPlanGenerator;
 
   @Before
   public void setUp() throws Exception {
-    irDAGBuilder = new DAGBuilder<>();
     final LocalMessageDispatcher messageDispatcher = new LocalMessageDispatcher();
     final LocalMessageEnvironment messageEnvironment =
         new LocalMessageEnvironment(MessageEnvironment.MASTER_COMMUNICATION_ID, messageDispatcher);
@@ -92,23 +84,18 @@ public final class JobStateManagerTest {
 
     for (int stageIdx = 0; stageIdx < stageList.size(); stageIdx++) {
       final Stage stage = stageList.get(stageIdx);
-      jobStateManager.onStageStateChanged(stage.getId(), StageState.State.EXECUTING);
       final List<String> taskIds = stage.getTaskIds();
       taskIds.forEach(taskId -> {
         jobStateManager.onTaskStateChanged(taskId, TaskState.State.EXECUTING);
         jobStateManager.onTaskStateChanged(taskId, TaskState.State.COMPLETE);
         if (RuntimeIdGenerator.getIndexFromTaskId(taskId) == taskIds.size() - 1) {
-          assertTrue(jobStateManager.checkStageCompletion(stage.getId()));
+          assertEquals(StageState.State.COMPLETE, jobStateManager.getStageState(stage.getId()));
         }
       });
-      final Map<String, TaskState> taskStateMap = jobStateManager.getIdToTaskStates();
-      taskIds.forEach(taskId -> {
-        assertEquals(taskStateMap.get(taskId).getStateMachine().getCurrentState(),
-            TaskState.State.COMPLETE);
-      });
+      taskIds.forEach(taskId -> assertEquals(jobStateManager.getTaskState(taskId), TaskState.State.COMPLETE));
 
       if (stageIdx == stageList.size() - 1) {
-        assertEquals(jobStateManager.getJobState().getStateMachine().getCurrentState(), JobState.State.COMPLETE);
+        assertEquals(jobStateManager.getJobState(), JobState.State.COMPLETE);
       }
     }
   }
@@ -116,26 +103,28 @@ public final class JobStateManagerTest {
   /**
    * Test whether the methods waiting finish of job works properly.
    */
-  @Test(timeout = 1000)
-  public void testWaitUntilFinish() {
-    // Create a JobStateManager of an empty dag.
-    final DAG<IRVertex, IREdge> irDAG = irDAGBuilder.build();
-    final DAG<Stage, StageEdge> physicalDAG = irDAG.convert(physicalPlanGenerator);
-    final JobStateManager jobStateManager = new JobStateManager(
-        new PhysicalPlan("TestPlan", physicalDAG),
-        blockManagerMaster, metricMessageHandler, MAX_SCHEDULE_ATTEMPT);
+  @Test(timeout = 2000)
+  public void testWaitUntilFinish() throws Exception {
+    final PhysicalPlan physicalPlan =
+        TestPlanGenerator.generatePhysicalPlan(TestPlanGenerator.PlanType.TwoVerticesJoined, false);
+    final JobStateManager jobStateManager =
+        new JobStateManager(physicalPlan, blockManagerMaster, metricMessageHandler, MAX_SCHEDULE_ATTEMPT);
 
-    assertFalse(jobStateManager.checkJobTermination());
+    assertFalse(jobStateManager.isJobDone());
 
     // Wait for the job to finish and check the job state.
     // It have to return EXECUTING state after timeout.
-    JobState state = jobStateManager.waitUntilFinish(100, TimeUnit.MILLISECONDS);
-    assertEquals(state.getStateMachine().getCurrentState(), JobState.State.EXECUTING);
+    final JobState.State executingState = jobStateManager.waitUntilFinish(100, TimeUnit.MILLISECONDS);
+    assertEquals(JobState.State.EXECUTING, executingState);
 
     // Complete the job and check the result again.
     // It have to return COMPLETE.
-    jobStateManager.onJobStateChanged(JobState.State.COMPLETE);
-    state = jobStateManager.waitUntilFinish();
-    assertEquals(state.getStateMachine().getCurrentState(), JobState.State.COMPLETE);
+    final List<String> tasks = physicalPlan.getStageDAG().getTopologicalSort().stream()
+        .flatMap(stage -> stage.getTaskIds().stream())
+        .collect(Collectors.toList());
+    tasks.forEach(taskId -> jobStateManager.onTaskStateChanged(taskId, TaskState.State.EXECUTING));
+    tasks.forEach(taskId -> jobStateManager.onTaskStateChanged(taskId, TaskState.State.COMPLETE));
+    final JobState.State completedState = jobStateManager.waitUntilFinish();
+    assertEquals(JobState.State.COMPLETE, completedState);
   }
 }
diff --git a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/BatchSingleJobSchedulerTest.java b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/BatchSingleJobSchedulerTest.java
index 2d1ff2e..f0f5d19 100644
--- a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/BatchSingleJobSchedulerTest.java
+++ b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/BatchSingleJobSchedulerTest.java
@@ -158,8 +158,7 @@ public final class BatchSingleJobSchedulerTest {
 
       LOG.debug("Checking that all stages of ScheduleGroup {} enter the executing state", scheduleGroupIdx);
       stages.forEach(stage -> {
-        while (jobStateManager.getStageState(stage.getId()).getStateMachine().getCurrentState()
-            != StageState.State.EXECUTING) {
+        while (jobStateManager.getStageState(stage.getId()) != StageState.State.EXECUTING) {
 
         }
       });
@@ -171,9 +170,9 @@ public final class BatchSingleJobSchedulerTest {
     }
 
     LOG.debug("Waiting for job termination after sending stage completion events");
-    while (!jobStateManager.checkJobTermination()) {
+    while (!jobStateManager.isJobDone()) {
     }
-    assertTrue(jobStateManager.checkJobTermination());
+    assertTrue(jobStateManager.isJobDone());
   }
 
   private List<Stage> filterStagesWithAScheduleGroupIndex(
diff --git a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/FaultToleranceTest.java b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/FaultToleranceTest.java
index 96c65eb..d2ec962 100644
--- a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/FaultToleranceTest.java
+++ b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/FaultToleranceTest.java
@@ -48,8 +48,6 @@ import java.util.concurrent.Executors;
 import java.util.function.Function;
 
 import static edu.snu.nemo.runtime.common.state.StageState.State.COMPLETE;
-import static edu.snu.nemo.runtime.common.state.StageState.State.EXECUTING;
-import static junit.framework.TestCase.assertFalse;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.Mockito.mock;
@@ -104,7 +102,7 @@ public final class FaultToleranceTest {
   /**
    * Tests fault tolerance after a container removal.
    */
-  @Test(timeout=10000)
+  @Test(timeout=5000)
   public void testContainerRemoval() throws Exception {
     final ActiveContext activeContext = mock(ActiveContext.class);
     Mockito.doThrow(new RuntimeException()).when(activeContext).close();
@@ -139,7 +137,7 @@ public final class FaultToleranceTest {
       if (stage.getScheduleGroupIndex() == 0 || stage.getScheduleGroupIndex() == 1) {
 
         // There are 3 executors, each of capacity 2, and there are 6 Tasks in ScheduleGroup 0 and 1.
-        SchedulerTestUtil.mockSchedulerRunner(pendingTaskCollection, schedulingPolicy, jobStateManager,
+        SchedulerTestUtil.mockSchedulingBySchedulerRunner(pendingTaskCollection, schedulingPolicy, jobStateManager,
             executorRegistry, false);
         assertTrue(pendingTaskCollection.isEmpty());
         stage.getTaskIds().forEach(taskId ->
@@ -148,18 +146,22 @@ public final class FaultToleranceTest {
       } else if (stage.getScheduleGroupIndex() == 2) {
         scheduler.onExecutorRemoved("a3");
         // There are 2 executors, each of capacity 2, and there are 2 Tasks in ScheduleGroup 2.
-        SchedulerTestUtil.mockSchedulerRunner(pendingTaskCollection, schedulingPolicy, jobStateManager,
+        SchedulerTestUtil.mockSchedulingBySchedulerRunner(pendingTaskCollection, schedulingPolicy, jobStateManager,
             executorRegistry, false);
 
         // Due to round robin scheduling, "a2" is assured to have a running Task.
         scheduler.onExecutorRemoved("a2");
 
-        while (jobStateManager.getStageState(stage.getId()).getStateMachine().getCurrentState() != EXECUTING) {
+        // Re-schedule
+        SchedulerTestUtil.mockSchedulingBySchedulerRunner(pendingTaskCollection, schedulingPolicy, jobStateManager,
+            executorRegistry, false);
 
-        }
-        assertEquals(jobStateManager.getAttemptCountForStage(stage.getId()), 2);
+        final Optional<Integer> maxTaskAttempt = stage.getTaskIds().stream()
+            .map(jobStateManager::getTaskAttempt).max(Integer::compareTo);
+        assertTrue(maxTaskAttempt.isPresent());
+        assertEquals(2, (int) maxTaskAttempt.get());
 
-        SchedulerTestUtil.mockSchedulerRunner(pendingTaskCollection, schedulingPolicy, jobStateManager,
+        SchedulerTestUtil.mockSchedulingBySchedulerRunner(pendingTaskCollection, schedulingPolicy, jobStateManager,
             executorRegistry, false);
         assertTrue(pendingTaskCollection.isEmpty());
         stage.getTaskIds().forEach(taskId ->
@@ -168,7 +170,7 @@ public final class FaultToleranceTest {
       } else if (stage.getScheduleGroupIndex() == 3) {
         // There are 1 executors, each of capacity 2, and there are 2 Tasks in ScheduleGroup 3.
         // Schedule only the first Task
-        SchedulerTestUtil.mockSchedulerRunner(pendingTaskCollection, schedulingPolicy, jobStateManager,
+        SchedulerTestUtil.mockSchedulingBySchedulerRunner(pendingTaskCollection, schedulingPolicy, jobStateManager,
             executorRegistry, true);
       } else {
         throw new RuntimeException(String.format("Unexpected ScheduleGroupIndex: %d",
@@ -180,7 +182,7 @@ public final class FaultToleranceTest {
   /**
    * Tests fault tolerance after an output write failure.
    */
-  @Test(timeout=10000)
+  @Test(timeout=5000)
   public void testOutputFailure() throws Exception {
     final ActiveContext activeContext = mock(ActiveContext.class);
     Mockito.doThrow(new RuntimeException()).when(activeContext).close();
@@ -213,7 +215,7 @@ public final class FaultToleranceTest {
       if (stage.getScheduleGroupIndex() == 0 || stage.getScheduleGroupIndex() == 1) {
 
         // There are 3 executors, each of capacity 2, and there are 6 Tasks in ScheduleGroup 0 and 1.
-        SchedulerTestUtil.mockSchedulerRunner(pendingTaskCollection, schedulingPolicy, jobStateManager,
+        SchedulerTestUtil.mockSchedulingBySchedulerRunner(pendingTaskCollection, schedulingPolicy, jobStateManager,
             executorRegistry, false);
         assertTrue(pendingTaskCollection.isEmpty());
         stage.getTaskIds().forEach(taskId ->
@@ -221,7 +223,7 @@ public final class FaultToleranceTest {
                 taskId, TaskState.State.COMPLETE, 1));
       } else if (stage.getScheduleGroupIndex() == 2) {
         // There are 3 executors, each of capacity 2, and there are 2 Tasks in ScheduleGroup 2.
-        SchedulerTestUtil.mockSchedulerRunner(pendingTaskCollection, schedulingPolicy, jobStateManager,
+        SchedulerTestUtil.mockSchedulingBySchedulerRunner(pendingTaskCollection, schedulingPolicy, jobStateManager,
             executorRegistry, false);
         assertTrue(pendingTaskCollection.isEmpty());
         stage.getTaskIds().forEach(taskId ->
@@ -229,16 +231,18 @@ public final class FaultToleranceTest {
                 taskId, TaskState.State.FAILED_RECOVERABLE, 1,
                 TaskState.RecoverableFailureCause.OUTPUT_WRITE_FAILURE));
 
-        while (jobStateManager.getStageState(stage.getId()).getStateMachine().getCurrentState() != EXECUTING) {
+        // Re-schedule
+        SchedulerTestUtil.mockSchedulingBySchedulerRunner(pendingTaskCollection, schedulingPolicy, jobStateManager,
+            executorRegistry, false);
 
-        }
+        final Optional<Integer> maxTaskAttempt = stage.getTaskIds().stream()
+            .map(jobStateManager::getTaskAttempt).max(Integer::compareTo);
+        assertTrue(maxTaskAttempt.isPresent());
+        assertEquals(2, (int) maxTaskAttempt.get());
 
-        assertEquals(3, jobStateManager.getAttemptCountForStage(stage.getId()));
-        assertFalse(pendingTaskCollection.isEmpty());
-        stage.getTaskIds().forEach(taskId -> {
-          assertEquals(jobStateManager.getTaskState(taskId).getStateMachine().getCurrentState(),
-              TaskState.State.READY);
-        });
+        assertTrue(pendingTaskCollection.isEmpty());
+        stage.getTaskIds().forEach(taskId ->
+            assertEquals(TaskState.State.EXECUTING, jobStateManager.getTaskState(taskId)));
       }
     }
   }
@@ -246,7 +250,7 @@ public final class FaultToleranceTest {
   /**
    * Tests fault tolerance after an input read failure.
    */
-  @Test(timeout=10000)
+  @Test(timeout=5000)
   public void testInputReadFailure() throws Exception {
     final ActiveContext activeContext = mock(ActiveContext.class);
     Mockito.doThrow(new RuntimeException()).when(activeContext).close();
@@ -279,7 +283,7 @@ public final class FaultToleranceTest {
       if (stage.getScheduleGroupIndex() == 0 || stage.getScheduleGroupIndex() == 1) {
 
         // There are 3 executors, each of capacity 2, and there are 6 Tasks in ScheduleGroup 0 and 1.
-        SchedulerTestUtil.mockSchedulerRunner(pendingTaskCollection, schedulingPolicy, jobStateManager,
+        SchedulerTestUtil.mockSchedulingBySchedulerRunner(pendingTaskCollection, schedulingPolicy, jobStateManager,
             executorRegistry, false);
         assertTrue(pendingTaskCollection.isEmpty());
         stage.getTaskIds().forEach(taskId ->
@@ -287,7 +291,7 @@ public final class FaultToleranceTest {
                 taskId, TaskState.State.COMPLETE, 1));
       } else if (stage.getScheduleGroupIndex() == 2) {
         // There are 3 executors, each of capacity 2, and there are 2 Tasks in ScheduleGroup 2.
-        SchedulerTestUtil.mockSchedulerRunner(pendingTaskCollection, schedulingPolicy, jobStateManager,
+        SchedulerTestUtil.mockSchedulingBySchedulerRunner(pendingTaskCollection, schedulingPolicy, jobStateManager,
             executorRegistry, false);
 
         stage.getTaskIds().forEach(taskId ->
@@ -295,15 +299,17 @@ public final class FaultToleranceTest {
                 taskId, TaskState.State.FAILED_RECOVERABLE, 1,
                 TaskState.RecoverableFailureCause.INPUT_READ_FAILURE));
 
-        while (jobStateManager.getStageState(stage.getId()).getStateMachine().getCurrentState() != EXECUTING) {
+        // Re-schedule
+        SchedulerTestUtil.mockSchedulingBySchedulerRunner(pendingTaskCollection, schedulingPolicy, jobStateManager,
+            executorRegistry, false);
 
-        }
+        final Optional<Integer> maxTaskAttempt = stage.getTaskIds().stream()
+            .map(jobStateManager::getTaskAttempt).max(Integer::compareTo);
+        assertTrue(maxTaskAttempt.isPresent());
+        assertEquals(2, (int) maxTaskAttempt.get());
 
-        assertEquals(2, jobStateManager.getAttemptCountForStage(stage.getId()));
-        stage.getTaskIds().forEach(taskId -> {
-          assertEquals(jobStateManager.getTaskState(taskId).getStateMachine().getCurrentState(),
-              TaskState.State.READY);
-        });
+        stage.getTaskIds().forEach(taskId ->
+            assertEquals(TaskState.State.EXECUTING, jobStateManager.getTaskState(taskId)));
       }
     }
   }
@@ -331,12 +337,12 @@ public final class FaultToleranceTest {
     final List<Stage> dagOf4Stages = plan.getStageDAG().getTopologicalSort();
 
     int executorIdIndex = 1;
-    float removalChance = 0.7f; // Out of 1.0
+    float removalChance = 0.5f; // Out of 1.0
     final Random random = new Random(0); // Deterministic seed.
 
     for (final Stage stage : dagOf4Stages) {
 
-      while (jobStateManager.getStageState(stage.getId()).getStateMachine().getCurrentState() != COMPLETE) {
+      while (jobStateManager.getStageState(stage.getId()) != COMPLETE) {
         // By chance, remove or add executor
         if (isTrueByChance(random, removalChance)) {
           // REMOVE EXECUTOR
@@ -370,7 +376,7 @@ public final class FaultToleranceTest {
         }
       }
     }
-    assertTrue(jobStateManager.checkJobTermination());
+    assertTrue(jobStateManager.isJobDone());
   }
 
   private boolean isTrueByChance(final Random random, final float chance) {
diff --git a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/SchedulerTestUtil.java b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/SchedulerTestUtil.java
index 55fbaef..cb15f7a 100644
--- a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/SchedulerTestUtil.java
+++ b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/SchedulerTestUtil.java
@@ -41,20 +41,20 @@ final class SchedulerTestUtil {
                             final int attemptIdx) {
     // Loop until the stage completes.
     while (true) {
-      final Enum stageState = jobStateManager.getStageState(stage.getId()).getStateMachine().getCurrentState();
+      final StageState.State stageState = jobStateManager.getStageState(stage.getId());
       if (StageState.State.COMPLETE == stageState) {
         // Stage has completed, so we break out of the loop.
         break;
       } else if (StageState.State.EXECUTING == stageState) {
         stage.getTaskIds().forEach(taskId -> {
-          final Enum tgState = jobStateManager.getTaskState(taskId).getStateMachine().getCurrentState();
-          if (TaskState.State.EXECUTING == tgState) {
+          final TaskState.State taskState = jobStateManager.getTaskState(taskId);
+          if (TaskState.State.EXECUTING == taskState) {
             sendTaskStateEventToScheduler(scheduler, executorRegistry, taskId,
                 TaskState.State.COMPLETE, attemptIdx, null);
-          } else if (TaskState.State.READY == tgState || TaskState.State.COMPLETE == tgState) {
+          } else if (TaskState.State.READY == taskState || TaskState.State.COMPLETE == taskState) {
             // Skip READY (try in the next loop and see if it becomes EXECUTING) and COMPLETE.
           } else {
-            throw new IllegalStateException(tgState.toString());
+            throw new IllegalStateException(taskState.toString());
           }
         });
       } else if (StageState.State.READY == stageState) {
@@ -88,7 +88,7 @@ final class SchedulerTestUtil {
         break;
       }
     }
-    scheduler.onTaskStateChanged(scheduledExecutor.getExecutorId(), taskId, attemptIdx,
+    scheduler.onTaskStateReportFromExecutor(scheduledExecutor.getExecutorId(), taskId, attemptIdx,
         newState, null, cause);
   }
 
@@ -100,17 +100,17 @@ final class SchedulerTestUtil {
     sendTaskStateEventToScheduler(scheduler, executorRegistry, taskId, newState, attemptIdx, null);
   }
 
-  static void mockSchedulerRunner(final PendingTaskCollection pendingTaskCollection,
-                                  final SchedulingPolicy schedulingPolicy,
-                                  final JobStateManager jobStateManager,
-                                  final ExecutorRegistry executorRegistry,
-                                  final boolean isPartialSchedule) {
+  static void mockSchedulingBySchedulerRunner(final PendingTaskCollection pendingTaskCollection,
+                                              final SchedulingPolicy schedulingPolicy,
+                                              final JobStateManager jobStateManager,
+                                              final ExecutorRegistry executorRegistry,
+                                              final boolean scheduleOnlyTheFirstStage) {
     final SchedulerRunner schedulerRunner =
         new SchedulerRunner(schedulingPolicy, pendingTaskCollection, executorRegistry);
     schedulerRunner.scheduleJob(jobStateManager);
     while (!pendingTaskCollection.isEmpty()) {
       schedulerRunner.doScheduleStage();
-      if (isPartialSchedule) {
+      if (scheduleOnlyTheFirstStage) {
         // Schedule only the first stage
         break;
       }
diff --git a/tests/src/test/java/edu/snu/nemo/tests/client/ClientEndpointTest.java b/tests/src/test/java/edu/snu/nemo/tests/client/ClientEndpointTest.java
index 0d60a5d..c962ad3 100644
--- a/tests/src/test/java/edu/snu/nemo/tests/client/ClientEndpointTest.java
+++ b/tests/src/test/java/edu/snu/nemo/tests/client/ClientEndpointTest.java
@@ -22,6 +22,7 @@ import edu.snu.nemo.common.dag.DAG;
 import edu.snu.nemo.common.dag.DAGBuilder;
 import edu.snu.nemo.common.ir.edge.IREdge;
 import edu.snu.nemo.common.ir.vertex.IRVertex;
+import edu.snu.nemo.common.test.EmptyComponents;
 import edu.snu.nemo.conf.JobConf;
 import edu.snu.nemo.runtime.common.message.MessageEnvironment;
 import edu.snu.nemo.runtime.common.message.local.LocalMessageDispatcher;
@@ -31,9 +32,11 @@ import edu.snu.nemo.runtime.common.plan.PhysicalPlanGenerator;
 import edu.snu.nemo.runtime.common.plan.Stage;
 import edu.snu.nemo.runtime.common.plan.StageEdge;
 import edu.snu.nemo.runtime.common.state.JobState;
+import edu.snu.nemo.runtime.common.state.TaskState;
 import edu.snu.nemo.runtime.master.MetricMessageHandler;
 import edu.snu.nemo.runtime.master.BlockManagerMaster;
 import edu.snu.nemo.runtime.master.JobStateManager;
+import edu.snu.nemo.runtime.plangenerator.TestPlanGenerator;
 import org.apache.reef.tang.Injector;
 import org.apache.reef.tang.Tang;
 import org.junit.Test;
@@ -41,7 +44,9 @@ import org.junit.runner.RunWith;
 import org.powermock.core.classloader.annotations.PrepareForTest;
 import org.powermock.modules.junit4.PowerMockRunner;
 
+import java.util.List;
 import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
 
 import static org.junit.Assert.assertEquals;
 import static org.mockito.Matchers.any;
@@ -68,22 +73,17 @@ public class ClientEndpointTest {
     // Wait for connection but not connected.
     assertEquals(clientEndpoint.waitUntilJobFinish(100, TimeUnit.MILLISECONDS), JobState.State.READY);
 
-    // Create a JobStateManager of an empty dag and create a DriverEndpoint with it.
-    final DAGBuilder<IRVertex, IREdge> irDagBuilder = new DAGBuilder<>();
-    final DAG<IRVertex, IREdge> irDAG = irDagBuilder.build();
-    final Injector injector = Tang.Factory.getTang().newInjector();
-    injector.bindVolatileParameter(JobConf.DAGDirectory.class, "");
-    final PhysicalPlanGenerator physicalPlanGenerator = injector.getInstance(PhysicalPlanGenerator.class);
-    final DAG<Stage, StageEdge> physicalDAG = irDAG.convert(physicalPlanGenerator);
-
+    // Create a JobStateManager of a dag and create a DriverEndpoint with it.
+    final PhysicalPlan physicalPlan =
+        TestPlanGenerator.generatePhysicalPlan(TestPlanGenerator.PlanType.TwoVerticesJoined, false);
     final LocalMessageDispatcher messageDispatcher = new LocalMessageDispatcher();
     final LocalMessageEnvironment messageEnvironment =
         new LocalMessageEnvironment(MessageEnvironment.MASTER_COMMUNICATION_ID, messageDispatcher);
+    final Injector injector = Tang.Factory.getTang().newInjector();
     injector.bindVolatileInstance(MessageEnvironment.class, messageEnvironment);
     final BlockManagerMaster pmm = injector.getInstance(BlockManagerMaster.class);
-    final JobStateManager jobStateManager = new JobStateManager(
-        new PhysicalPlan("TestPlan", physicalDAG),
-        pmm, metricMessageHandler, MAX_SCHEDULE_ATTEMPT);
+    final JobStateManager jobStateManager =
+        new JobStateManager(physicalPlan, pmm, metricMessageHandler, MAX_SCHEDULE_ATTEMPT);
 
     final DriverEndpoint driverEndpoint = new DriverEndpoint(jobStateManager, clientEndpoint);
 
@@ -94,8 +94,12 @@ public class ClientEndpointTest {
     assertEquals(clientEndpoint.waitUntilJobFinish(100, TimeUnit.MILLISECONDS), JobState.State.EXECUTING);
 
     // Check finish.
-    jobStateManager.onJobStateChanged(JobState.State.COMPLETE);
-    assertEquals(clientEndpoint.waitUntilJobFinish(), JobState.State.COMPLETE);
+    final List<String> tasks = physicalPlan.getStageDAG().getTopologicalSort().stream()
+        .flatMap(stage -> stage.getTaskIds().stream())
+        .collect(Collectors.toList());
+    tasks.forEach(taskId -> jobStateManager.onTaskStateChanged(taskId, TaskState.State.EXECUTING));
+    tasks.forEach(taskId -> jobStateManager.onTaskStateChanged(taskId, TaskState.State.COMPLETE));
+    assertEquals(JobState.State.COMPLETE, clientEndpoint.waitUntilJobFinish());
   }
 
   /**


Mime
View raw message