tez-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jeag...@apache.org
Subject [01/25] TEZ-850. Recovery unit tests. (Jeff Zhang via hitesh)
Date Thu, 18 Sep 2014 19:49:46 GMT
Repository: tez
Updated Branches:
  refs/heads/TEZ-8 d6589d3ac -> 625450cf1


http://git-wip-us.apache.org/repos/asf/tez/blob/f65e65ae/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestTaskRecovery.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestTaskRecovery.java b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestTaskRecovery.java
index cd9a1e8..c5153b6 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestTaskRecovery.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestTaskRecovery.java
@@ -19,29 +19,28 @@
 package org.apache.tez.dag.app.dag.impl;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 import static org.mockito.Matchers.any;
 import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.security.Credentials;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
-import org.apache.hadoop.yarn.api.records.Container;
 import org.apache.hadoop.yarn.api.records.ContainerId;
-import org.apache.hadoop.yarn.api.records.LocalResource;
 import org.apache.hadoop.yarn.api.records.NodeId;
 import org.apache.hadoop.yarn.api.records.Resource;
 import org.apache.hadoop.yarn.event.DrainDispatcher;
 import org.apache.hadoop.yarn.event.EventHandler;
-import org.apache.hadoop.yarn.util.Clock;
 import org.apache.hadoop.yarn.util.SystemClock;
+import org.apache.tez.common.counters.TezCounters;
 import org.apache.tez.dag.api.TezConfiguration;
+import org.apache.tez.dag.api.client.VertexStatus.State;
 import org.apache.tez.dag.api.oldrecords.TaskAttemptState;
 import org.apache.tez.dag.api.oldrecords.TaskState;
 import org.apache.tez.dag.app.AppContext;
@@ -58,110 +57,532 @@ import org.apache.tez.dag.app.dag.event.TaskEvent;
 import org.apache.tez.dag.app.dag.event.TaskEventRecoverTask;
 import org.apache.tez.dag.app.dag.event.TaskEventType;
 import org.apache.tez.dag.app.dag.event.VertexEventType;
-import org.apache.tez.dag.app.rm.container.AMContainer;
 import org.apache.tez.dag.history.events.TaskAttemptFinishedEvent;
 import org.apache.tez.dag.history.events.TaskAttemptStartedEvent;
+import org.apache.tez.dag.history.events.TaskFinishedEvent;
 import org.apache.tez.dag.history.events.TaskStartedEvent;
 import org.apache.tez.dag.records.TezDAGID;
 import org.apache.tez.dag.records.TezTaskAttemptID;
 import org.apache.tez.dag.records.TezTaskID;
 import org.apache.tez.dag.records.TezVertexID;
+import org.apache.tez.runtime.api.OutputCommitter;
+import org.apache.tez.runtime.api.OutputCommitterContext;
 import org.junit.Before;
 import org.junit.Test;
 
+import com.google.common.collect.Lists;
+
 public class TestTaskRecovery {
 
-  private static final Log LOG = LogFactory.getLog(TestTaskImpl.class);
+  private TaskImpl task;
+  private DrainDispatcher dispatcher;
 
-  private int taskCounter = 0;
   private int taskAttemptCounter = 0;
 
-  private Configuration conf;
-  private TaskAttemptListener taskAttemptListener;
-  private TaskHeartbeatHandler taskHeartbeatHandler;
-  private Credentials credentials;
-  private Clock clock;
-  private ApplicationId appId;
-  private TezDAGID dagId;
-  private TezVertexID vertexId;
+  private Configuration conf = new Configuration();
+  private AppContext mockAppContext;
+  private ApplicationId appId = ApplicationId.newInstance(
+      System.currentTimeMillis(), 1);
+  private TezDAGID dagId = TezDAGID.getInstance(appId, 1);
+  private TezVertexID vertexId = TezVertexID.getInstance(dagId, 1);
   private Vertex vertex;
-  private AppContext appContext;
-  private Resource taskResource;
-  private Map<String, LocalResource> localResources;
-  private Map<String, String> environment;
-  private String javaOpts;
-  private boolean leafVertex;
-  private ContainerContext containerContext;
-  private ContainerId mockContainerId;
-  private Container mockContainer;
-  private AMContainer mockAMContainer;
-  private NodeId mockNodeId;
-
-  private TaskImpl task;
-  private DrainDispatcher dispatcher;
+  private String vertexName = "v1";
+  private long taskScheduledTime = 100L;
+  private long taskStartTime = taskScheduledTime + 100L;
+  private long taskFinishTime = taskStartTime + 100L;
+  private TaskAttemptEventHandler taEventHandler =
+      new TaskAttemptEventHandler();
 
-  class TaskEventHandler implements EventHandler<TaskEvent> {
+  private class TaskEventHandler implements EventHandler<TaskEvent> {
     @Override
     public void handle(TaskEvent event) {
       task.handle(event);
     }
   }
 
-  class TaskAttemptEventHandler implements EventHandler<TaskAttemptEvent> {
+  private class TaskAttemptEventHandler implements
+      EventHandler<TaskAttemptEvent> {
+
+    private List<TaskAttemptEvent> events = Lists.newArrayList();
+
     @Override
     public void handle(TaskAttemptEvent event) {
+      events.add(event);
       ((TaskAttemptImpl) task.getAttempt(event.getTaskAttemptID()))
           .handle(event);
     }
+
+    public List<TaskAttemptEvent> getEvents() {
+      return events;
+    }
   }
 
-  @Before
-  public void setUp() {
-    conf = new Configuration();
-    taskAttemptListener = mock(TaskAttemptListener.class);
-    taskHeartbeatHandler = mock(TaskHeartbeatHandler.class);
-    credentials = new Credentials();
-    clock = new SystemClock();
-    appId = ApplicationId.newInstance(System.currentTimeMillis(), 1);
-    dagId = TezDAGID.getInstance(appId, 1);
-    vertexId = TezVertexID.getInstance(dagId, 1);
-    vertex = mock(Vertex.class, RETURNS_DEEP_STUBS);
-    when(vertex.getProcessorDescriptor().getClassName()).thenReturn("");
-    appContext = mock(AppContext.class, RETURNS_DEEP_STUBS);
-    mockContainerId = mock(ContainerId.class);
-    mockContainer = mock(Container.class);
-    mockAMContainer = mock(AMContainer.class);
-    mockNodeId = mock(NodeId.class);
-    when(mockContainer.getId()).thenReturn(mockContainerId);
-    when(mockContainer.getNodeId()).thenReturn(mockNodeId);
-    when(mockAMContainer.getContainer()).thenReturn(mockContainer);
-    when(appContext.getAllContainers().get(mockContainerId)).thenReturn(
-        mockAMContainer);
-    when(appContext.getCurrentDAG().getVertex(any(TezVertexID.class)))
-        .thenReturn(vertex);
-    when(vertex.getProcessorDescriptor().getClassName()).thenReturn("");
+  private class TestOutputCommitter extends OutputCommitter {
+
+    boolean recoverySupported = false;
+    boolean throwExceptionWhenRecovery = false;
+
+    public TestOutputCommitter(OutputCommitterContext committerContext,
+        boolean recoverySupported, boolean throwExceptionWhenRecovery) {
+      super(committerContext);
+      this.recoverySupported = recoverySupported;
+      this.throwExceptionWhenRecovery = throwExceptionWhenRecovery;
+    }
+
+    @Override
+    public void recoverTask(int taskIndex, int previousDAGAttempt)
+        throws Exception {
+      if (throwExceptionWhenRecovery) {
+        throw new Exception("fail recovery Task");
+      }
+    }
+
+    @Override
+    public boolean isTaskRecoverySupported() {
+      return recoverySupported;
+    }
+
+    @Override
+    public void initialize() throws Exception {
+
+    }
+
+    @Override
+    public void setupOutput() throws Exception {
+
+    }
+
+    @Override
+    public void commitOutput() throws Exception {
+
+    }
+
+    @Override
+    public void abortOutput(State finalState) throws Exception {
 
-    taskResource = Resource.newInstance(1024, 1);
-    localResources = new HashMap<String, LocalResource>();
-    environment = new HashMap<String, String>();
-    javaOpts = "";
-    leafVertex = false;
-    containerContext =
-        new ContainerContext(localResources, credentials, environment, javaOpts);
+    }
+
+  }
 
+  @Before
+  public void setUp() {
     dispatcher = new DrainDispatcher();
     dispatcher.register(DAGEventType.class, mock(EventHandler.class));
     dispatcher.register(VertexEventType.class, mock(EventHandler.class));
     dispatcher.register(TaskEventType.class, new TaskEventHandler());
-    dispatcher.register(TaskAttemptEventType.class,
-        new TaskAttemptEventHandler());
+    dispatcher.register(TaskAttemptEventType.class, taEventHandler);
     dispatcher.init(new Configuration());
     dispatcher.start();
 
+    vertex = mock(Vertex.class, RETURNS_DEEP_STUBS);
+    when(vertex.getProcessorDescriptor().getClassName()).thenReturn("");
+
+    mockAppContext = mock(AppContext.class, RETURNS_DEEP_STUBS);
+    when(mockAppContext.getCurrentDAG().getVertex(any(TezVertexID.class)))
+        .thenReturn(vertex);
+
     task =
-        new TaskImpl(vertexId, 1, dispatcher.getEventHandler(), conf,
-            taskAttemptListener, clock, taskHeartbeatHandler, appContext,
-            leafVertex, taskResource, containerContext);
+        new TaskImpl(vertexId, 0, dispatcher.getEventHandler(),
+            new Configuration(), mock(TaskAttemptListener.class),
+            new SystemClock(), mock(TaskHeartbeatHandler.class),
+            mockAppContext, false, Resource.newInstance(1, 1),
+            mock(ContainerContext.class));
+
+    Map<String, OutputCommitter> committers =
+        new HashMap<String, OutputCommitter>();
+    committers.put("out1", new TestOutputCommitter(
+        mock(OutputCommitterContext.class), true, false));
+    when(task.getVertex().getOutputCommitters()).thenReturn(committers);
+  }
+
+  private void restoreFromTaskStartEvent() {
+    TaskState recoveredState =
+        task.restoreFromEvent(new TaskStartedEvent(task.getTaskId(),
+            vertexName, taskScheduledTime, taskStartTime));
+    assertEquals(TaskState.SCHEDULED, recoveredState);
+    assertEquals(0, task.finishedAttempts);
+    assertEquals(taskScheduledTime, task.scheduledTime);
+    assertEquals(0, task.getAttempts().size());
+  }
+
+  private void restoreFromFirstTaskAttemptStartEvent(TezTaskAttemptID taId) {
+    long taStartTime = taskStartTime + 100L;
+    TaskState recoveredState =
+        task.restoreFromEvent(new TaskAttemptStartedEvent(taId, vertexName,
+            taStartTime, mock(ContainerId.class), mock(NodeId.class), "", ""));
+    assertEquals(TaskState.RUNNING, recoveredState);
+    assertEquals(0, task.finishedAttempts);
+    assertEquals(taskScheduledTime, task.scheduledTime);
+    assertEquals(1, task.getAttempts().size());
+    assertEquals(TaskAttemptStateInternal.NEW,
+        ((TaskAttemptImpl) task.getAttempt(taId)).getInternalState());
+    assertEquals(1, task.numberUncompletedAttempts);
+  }
+
+  /**
+   * New -> RecoverTransition
+   */
+  @Test
+  public void testRecovery_New() {
+    task.handle(new TaskEventRecoverTask(task.getTaskId()));
+    assertEquals(TaskStateInternal.NEW, task.getInternalState());
+  }
+
+  /**
+   * -> restoreFromTaskFinishEvent ( no TaskStartEvent )
+   */
+  @Test
+  public void testRecovery_NoStartEvent() {
+    try {
+      task.restoreFromEvent(new TaskFinishedEvent(task.getTaskId(), vertexName,
+          taskStartTime, taskFinishTime, null, TaskState.SUCCEEDED, "",
+          new TezCounters()));
+      fail("Should fail due to no TaskStartEvent before TaskFinishEvent");
+    } catch (Throwable e) {
+      assertTrue(e.getMessage().contains(
+          "Finished Event seen but"
+              + " no Started Event was encountered earlier"));
+    }
+  }
+
+  /**
+   * restoreFromTaskStartedEvent -> RecoverTransition
+   */
+  @Test
+  public void testRecovery_Started() {
+    restoreFromTaskStartEvent();
+
+    task.handle(new TaskEventRecoverTask(task.getTaskId()));
+    assertEquals(TaskStateInternal.RUNNING, task.getInternalState());
+    // new task attempt is scheduled
+    assertEquals(1, task.getAttempts().size());
+    assertEquals(0, task.finishedAttempts);
+    assertEquals(0, task.failedAttempts);
+    assertEquals(null, task.successfulAttempt);
+  }
+
+  /**
+   * restoreFromTaskStartedEvent -> restoreFromTaskAttemptStartedEvent ->
+   * RecoverTransition
+   */
+  @Test
+  public void testRecovery_OneTAStarted() {
+    restoreFromTaskStartEvent();
+    TezTaskAttemptID taId = getNewTaskAttemptID(task.getTaskId());
+    restoreFromFirstTaskAttemptStartEvent(taId);
+
+    task.handle(new TaskEventRecoverTask(task.getTaskId()));
+    assertEquals(TaskStateInternal.RUNNING, task.getInternalState());
+    assertEquals(1, task.getAttempts().size());
+    assertEquals(0, task.finishedAttempts);
+    assertEquals(0, task.failedAttempts);
+    assertEquals(null, task.successfulAttempt);
+  }
+
+  /**
+   * restoreFromTaskStartedEvent -> restoreFromTaskAttemptStartedEvent ->
+   * restoreFromTaskAttemptFinishedEvent (SUCCEEDED) -> RecoverTransition
+   */
+  @Test
+  public void testRecovery_OneTAStarted_SUCCEEDED() {
+    restoreFromTaskStartEvent();
+    TezTaskAttemptID taId = getNewTaskAttemptID(task.getTaskId());
+    restoreFromFirstTaskAttemptStartEvent(taId);
+
+    long taStartTime = taskStartTime + 100L;
+    long taFinishTime = taStartTime + 100L;
+    TaskState recoveredState =
+        task.restoreFromEvent(new TaskAttemptFinishedEvent(taId, vertexName,
+            taStartTime, taFinishTime, TaskAttemptState.SUCCEEDED, "",
+            new TezCounters()));
+    assertEquals(TaskState.SUCCEEDED, recoveredState);
+    assertEquals(1, task.getAttempts().size());
+    assertEquals(1, task.finishedAttempts);
+    assertEquals(0, task.failedAttempts);
+    assertEquals(0, task.numberUncompletedAttempts);
+    assertEquals(taId, task.successfulAttempt);
+
+    task.handle(new TaskEventRecoverTask(task.getTaskId()));
+    assertEquals(TaskStateInternal.SUCCEEDED, task.getInternalState());
+    assertEquals(1, task.getAttempts().size());
+    assertEquals(1, task.finishedAttempts);
+    assertEquals(0, task.failedAttempts);
+    assertEquals(0, task.numberUncompletedAttempts);
+    assertEquals(taId, task.successfulAttempt);
+  }
+
+  /**
+   * restoreFromTaskStartedEvent -> restoreFromTaskAttemptStartedEvent ->
+   * restoreFromTaskAttemptFinishedEvent (FAILED) -> RecoverTransition
+   */
+  @Test
+  public void testRecovery_OneTAStarted_FAILED() {
+    restoreFromTaskStartEvent();
+    TezTaskAttemptID taId = getNewTaskAttemptID(task.getTaskId());
+    restoreFromFirstTaskAttemptStartEvent(taId);
+
+    long taStartTime = taskStartTime + 100L;
+    long taFinishTime = taStartTime + 100L;
+    TaskState recoveredState =
+        task.restoreFromEvent(new TaskAttemptFinishedEvent(taId, vertexName,
+            taStartTime, taFinishTime, TaskAttemptState.FAILED, "",
+            new TezCounters()));
+    assertEquals(TaskState.RUNNING, recoveredState);
+    assertEquals(1, task.getAttempts().size());
+    assertEquals(1, task.finishedAttempts);
+    assertEquals(1, task.failedAttempts);
+    assertEquals(0, task.numberUncompletedAttempts);
+    assertEquals(null, task.successfulAttempt);
+
+    task.handle(new TaskEventRecoverTask(task.getTaskId()));
+    assertEquals(TaskStateInternal.RUNNING, task.getInternalState());
+    // new task attempt is scheduled
+    assertEquals(2, task.getAttempts().size());
+    assertEquals(1, task.finishedAttempts);
+    assertEquals(1, task.failedAttempts);
+    assertEquals(1, task.numberUncompletedAttempts);
+    assertEquals(null, task.successfulAttempt);
+  }
+
+  /**
+   * restoreFromTaskStartedEvent -> restoreFromTaskAttemptStartedEvent ->
+   * restoreFromTaskAttemptFinishedEvent (KILLED) -> RecoverTransition
+   */
+  @Test
+  public void testRecovery_OneTAStarted_KILLED() {
+    restoreFromTaskStartEvent();
+    TezTaskAttemptID taId = getNewTaskAttemptID(task.getTaskId());
+    restoreFromFirstTaskAttemptStartEvent(taId);
+
+    long taStartTime = taskStartTime + 100L;
+    long taFinishTime = taStartTime + 100L;
+    TaskState recoveredState =
+        task.restoreFromEvent(new TaskAttemptFinishedEvent(taId, vertexName,
+            taStartTime, taFinishTime, TaskAttemptState.KILLED, "",
+            new TezCounters()));
+    assertEquals(TaskState.RUNNING, recoveredState);
+    assertEquals(1, task.getAttempts().size());
+    assertEquals(1, task.finishedAttempts);
+    assertEquals(0, task.failedAttempts);
+    assertEquals(0, task.numberUncompletedAttempts);
+    assertEquals(null, task.successfulAttempt);
+
+    task.handle(new TaskEventRecoverTask(task.getTaskId()));
+    assertEquals(TaskStateInternal.RUNNING, task.getInternalState());
+    // new task attempt is scheduled
+    assertEquals(2, task.getAttempts().size());
+    assertEquals(1, task.finishedAttempts);
+    assertEquals(0, task.failedAttempts);
+    assertEquals(1, task.numberUncompletedAttempts);
+    assertEquals(null, task.successfulAttempt);
+  }
+
+  /**
+   * restoreFromTaskStartedEvent -> restoreFromTaskAttemptStartedEvent ->
+   * restoreFromTaskAttemptFinishedEvent (SUCCEEDED) ->
+   * restoreFromTaskFinishedEvent -> RecoverTransition
+   */
+  @Test
+  public void testRecovery_OneTAStarted_SUCCEEDED_Finished() {
+
+    restoreFromTaskStartEvent();
+    TezTaskAttemptID taId = getNewTaskAttemptID(task.getTaskId());
+    restoreFromFirstTaskAttemptStartEvent(taId);
+
+    long taStartTime = taskStartTime + 100L;
+    long taFinishTime = taStartTime + 100L;
+    TaskState recoveredState =
+        task.restoreFromEvent(new TaskAttemptFinishedEvent(taId, vertexName,
+            taStartTime, taFinishTime, TaskAttemptState.SUCCEEDED, "",
+            new TezCounters()));
+    assertEquals(TaskState.SUCCEEDED, recoveredState);
+    assertEquals(1, task.getAttempts().size());
+    assertEquals(1, task.finishedAttempts);
+    assertEquals(0, task.failedAttempts);
+    assertEquals(0, task.numberUncompletedAttempts);
+    assertEquals(taId, task.successfulAttempt);
+
+    recoveredState =
+        task.restoreFromEvent(new TaskFinishedEvent(task.getTaskId(),
+            vertexName, taskStartTime, taskFinishTime, taId,
+            TaskState.SUCCEEDED, "", new TezCounters()));
+    assertEquals(TaskState.SUCCEEDED, recoveredState);
+    assertEquals(taId, task.successfulAttempt);
+
+    task.handle(new TaskEventRecoverTask(task.getTaskId()));
+    assertEquals(TaskStateInternal.SUCCEEDED, task.getInternalState());
+    assertEquals(1, task.getAttempts().size());
+    assertEquals(1, task.finishedAttempts);
+    assertEquals(0, task.failedAttempts);
+    assertEquals(0, task.numberUncompletedAttempts);
+    assertEquals(taId, task.successfulAttempt);
+  }
+
+  /**
+   * restoreFromTaskStartedEvent -> restoreFromTaskAttemptStartedEvent ->
+   * restoreFromTaskAttemptFinishedEvent (SUCCEEDED) -> RecoverTransition
+   */
+  @Test
+  public void testRecovery_Commit_Failed_Recovery_Not_Supported() {
+    Map<String, OutputCommitter> committers =
+        new HashMap<String, OutputCommitter>();
+    committers.put("out1", new TestOutputCommitter(
+        mock(OutputCommitterContext.class), false, false));
+    when(task.getVertex().getOutputCommitters()).thenReturn(committers);
+
+    restoreFromTaskStartEvent();
+    TezTaskAttemptID taId = getNewTaskAttemptID(task.getTaskId());
+    restoreFromFirstTaskAttemptStartEvent(taId);
+
+    // restoreFromTaskAttemptFinishedEvent (SUCCEEDED)
+    long taStartTime = taskStartTime + 100L;
+    long taFinishTime = taStartTime + 100L;
+    TaskState recoveredState =
+        task.restoreFromEvent(new TaskAttemptFinishedEvent(taId, vertexName,
+            taStartTime, taFinishTime, TaskAttemptState.SUCCEEDED, "",
+            new TezCounters()));
+    assertEquals(TaskState.SUCCEEDED, recoveredState);
+    assertEquals(1, task.getAttempts().size());
+    assertEquals(1, task.finishedAttempts);
+    assertEquals(0, task.failedAttempts);
+    assertEquals(0, task.numberUncompletedAttempts);
+    assertEquals(taId, task.successfulAttempt);
+
+    task.handle(new TaskEventRecoverTask(task.getTaskId()));
+    assertEquals(TaskStateInternal.RUNNING, task.getInternalState());
+    // new task attempt is scheduled
+    assertEquals(2, task.getAttempts().size());
+    assertEquals(1, task.finishedAttempts);
+    assertEquals(0, task.failedAttempts);
+    assertEquals(1, task.numberUncompletedAttempts);
+    assertEquals(null, task.successfulAttempt);
+  }
+
+  /**
+   * restoreFromTaskStartedEvent -> restoreFromTaskAttemptStartedEvent ->
+   * restoreFromTaskAttemptFinishedEvent (SUCCEEDED) -> RecoverTransition
+   */
+  @Test
+  public void testRecovery_Commit_Failed_recover_fail() {
+    Map<String, OutputCommitter> committers =
+        new HashMap<String, OutputCommitter>();
+    committers.put("out1", new TestOutputCommitter(
+        mock(OutputCommitterContext.class), true, true));
+    when(task.getVertex().getOutputCommitters()).thenReturn(committers);
+
+    restoreFromTaskStartEvent();
+    TezTaskAttemptID taId = getNewTaskAttemptID(task.getTaskId());
+    restoreFromFirstTaskAttemptStartEvent(taId);
+
+    // restoreFromTaskAttemptFinishedEvent (SUCCEEDED)
+    long taStartTime = taskStartTime + 100L;
+    long taFinishTime = taStartTime + 100L;
+    TaskState recoveredState =
+        task.restoreFromEvent(new TaskAttemptFinishedEvent(taId, vertexName,
+            taStartTime, taFinishTime, TaskAttemptState.SUCCEEDED, "",
+            new TezCounters()));
+    assertEquals(TaskState.SUCCEEDED, recoveredState);
+    assertEquals(1, task.getAttempts().size());
+    assertEquals(1, task.finishedAttempts);
+    assertEquals(0, task.failedAttempts);
+    assertEquals(0, task.numberUncompletedAttempts);
+    assertEquals(taId, task.successfulAttempt);
+
+    task.handle(new TaskEventRecoverTask(task.getTaskId()));
+    assertEquals(TaskStateInternal.RUNNING, task.getInternalState());
+    // new task attempt is scheduled
+    assertEquals(2, task.getAttempts().size());
+    assertEquals(1, task.finishedAttempts);
+    assertEquals(0, task.failedAttempts);
+    assertEquals(1, task.numberUncompletedAttempts);
+    assertEquals(null, task.successfulAttempt);
+  }
+
+  @Test
+  public void testRecovery_WithDesired_SUCCEEDED() {
+    restoreFromTaskStartEvent();
+    TezTaskAttemptID taId = getNewTaskAttemptID(task.getTaskId());
+    restoreFromFirstTaskAttemptStartEvent(taId);
+    task.handle(new TaskEventRecoverTask(task.getTaskId(), TaskState.SUCCEEDED,
+        false));
+    assertEquals(TaskStateInternal.SUCCEEDED, task.getInternalState());
+    // no TA_Recovery event sent
+    assertEquals(0, taEventHandler.getEvents().size());
+  }
+
+  @Test
+  public void testRecovery_WithDesired_FAILED() {
+    restoreFromTaskStartEvent();
+    TezTaskAttemptID taId = getNewTaskAttemptID(task.getTaskId());
+    restoreFromFirstTaskAttemptStartEvent(taId);
+    task.handle(new TaskEventRecoverTask(task.getTaskId(), TaskState.FAILED,
+        false));
+    assertEquals(TaskStateInternal.FAILED, task.getInternalState());
+    // no TA_Recovery event sent
+    assertEquals(0, taEventHandler.getEvents().size());
+  }
+
+  @Test
+  public void testRecovery_WithDesired_KILLED() {
+    restoreFromTaskStartEvent();
+    TezTaskAttemptID taId = getNewTaskAttemptID(task.getTaskId());
+    restoreFromFirstTaskAttemptStartEvent(taId);
+    task.handle(new TaskEventRecoverTask(task.getTaskId(), TaskState.KILLED,
+        false));
+    assertEquals(TaskStateInternal.KILLED, task.getInternalState());
+    // no TA_Recovery event sent
+    assertEquals(0, taEventHandler.getEvents().size());
+
+  }
+
+  /**
+   * restoreFromTaskStartedEvent -> restoreFromTaskAttemptStartedEvent ->
+   * restoreFromTaskAttemptFinishedEvent (KILLED) -> RecoverTransition
+   */
+  @Test
+  public void testRecovery_OneTAStarted_Killed() {
+    restoreFromTaskStartEvent();
+
+    long taStartTime = taskStartTime + 100L;
+    TezTaskAttemptID taId = getNewTaskAttemptID(task.getTaskId());
+    TaskState recoveredState =
+        task.restoreFromEvent(new TaskAttemptStartedEvent(taId, vertexName,
+            taStartTime, mock(ContainerId.class), mock(NodeId.class), "", ""));
+    assertEquals(TaskState.RUNNING, recoveredState);
+    assertEquals(TaskAttemptStateInternal.NEW,
+        ((TaskAttemptImpl) task.getAttempt(taId)).getInternalState());
+    assertEquals(1, task.getAttempts().size());
+    assertEquals(0, task.finishedAttempts);
+    assertEquals(0, task.failedAttempts);
+    assertEquals(1, task.numberUncompletedAttempts);
+    assertEquals(null, task.successfulAttempt);
+
+    long taFinishTime = taStartTime + 100L;
+    recoveredState =
+        task.restoreFromEvent(new TaskAttemptFinishedEvent(taId, vertexName,
+            taStartTime, taFinishTime, TaskAttemptState.KILLED, "",
+            new TezCounters()));
+    assertEquals(TaskState.RUNNING, recoveredState);
+    assertEquals(TaskAttemptStateInternal.NEW,
+        ((TaskAttemptImpl) task.getAttempt(taId)).getInternalState());
+    assertEquals(1, task.getAttempts().size());
+    assertEquals(1, task.finishedAttempts);
+    assertEquals(0, task.failedAttempts);
+    assertEquals(0, task.numberUncompletedAttempts);
+    assertEquals(null, task.successfulAttempt);
+
+    task.handle(new TaskEventRecoverTask(task.getTaskId()));
+    // wait for Task to send TA_RECOVER to TA and for TA to complete the RecoverTransition
+    dispatcher.await();
+    assertEquals(TaskStateInternal.RUNNING, task.getInternalState());
+    assertEquals(TaskAttemptStateInternal.KILLED,
+        ((TaskAttemptImpl) task.getAttempt(taId)).getInternalState());
+    // new task attempt is scheduled
+    assertEquals(2, task.getAttempts().size());
+    assertEquals(1, task.finishedAttempts);
+    assertEquals(0, task.failedAttempts);
+    assertEquals(1, task.numberUncompletedAttempts);
+    assertEquals(null, task.successfulAttempt);
   }
 
   /**
@@ -170,24 +591,23 @@ public class TestTaskRecovery {
    * schedule a new task attempt.
    */
   @Test
-  public void testTaskRecovery1() {
-    TezTaskID lastTaskId = getNewTaskID();
-    TezTaskID taskId = getNewTaskID();
+  public void testTaskRecovery_MultipleAttempts1() {
     int maxFailedAttempts =
         conf.getInt(TezConfiguration.TEZ_AM_TASK_MAX_FAILED_ATTEMPTS,
             TezConfiguration.TEZ_AM_TASK_MAX_FAILED_ATTEMPTS_DEFAULT);
-    task.restoreFromEvent(new TaskStartedEvent(taskId, "v1", 0, 0));
+    restoreFromTaskStartEvent();
+
     for (int i = 0; i < maxFailedAttempts; ++i) {
-      TezTaskAttemptID attemptId = getNewTaskAttemptID(lastTaskId);
-      task.restoreFromEvent(new TaskAttemptStartedEvent(attemptId, "v1", 0,
-          mockContainerId, mockNodeId, "", ""));
-      task.restoreFromEvent(new TaskAttemptFinishedEvent(attemptId, "v1", 0, 0,
-          TaskAttemptState.KILLED, "", null));
+      TezTaskAttemptID taId = getNewTaskAttemptID(task.getTaskId());
+      task.restoreFromEvent(new TaskAttemptStartedEvent(taId, vertexName, 0L,
+          mock(ContainerId.class), mock(NodeId.class), "", ""));
+      task.restoreFromEvent(new TaskAttemptFinishedEvent(taId, vertexName, 0,
+          0, TaskAttemptState.KILLED, "", null));
     }
     assertEquals(maxFailedAttempts, task.getAttempts().size());
     assertEquals(0, task.failedAttempts);
 
-    task.handle(new TaskEventRecoverTask(lastTaskId));
+    task.handle(new TaskEventRecoverTask(task.getTaskId()));
     // if the previous task attempt is killed, it should not been take into
     // account when checking whether exceed the max attempts
     assertEquals(TaskStateInternal.RUNNING, task.getInternalState());
@@ -201,24 +621,23 @@ public class TestTaskRecovery {
    * failed_attempt is exceeded.
    */
   @Test
-  public void testTaskRecovery2() {
-    TezTaskID lastTaskId = getNewTaskID();
-    TezTaskID taskId = getNewTaskID();
+  public void testTaskRecovery_MultipleAttempts2() {
     int maxFailedAttempts =
         conf.getInt(TezConfiguration.TEZ_AM_TASK_MAX_FAILED_ATTEMPTS,
             TezConfiguration.TEZ_AM_TASK_MAX_FAILED_ATTEMPTS_DEFAULT);
-    task.restoreFromEvent(new TaskStartedEvent(taskId, "v1", 0, 0));
+    restoreFromTaskStartEvent();
+
     for (int i = 0; i < maxFailedAttempts; ++i) {
-      TezTaskAttemptID attemptId = getNewTaskAttemptID(lastTaskId);
-      task.restoreFromEvent(new TaskAttemptStartedEvent(attemptId, "v1", 0,
-          mockContainerId, mockNodeId, "", ""));
-      task.restoreFromEvent(new TaskAttemptFinishedEvent(attemptId, "v1", 0, 0,
-          TaskAttemptState.FAILED, "", null));
+      TezTaskAttemptID taId = getNewTaskAttemptID(task.getTaskId());
+      task.restoreFromEvent(new TaskAttemptStartedEvent(taId, vertexName, 0L,
+          mock(ContainerId.class), mock(NodeId.class), "", ""));
+      task.restoreFromEvent(new TaskAttemptFinishedEvent(taId, vertexName, 0,
+          0, TaskAttemptState.FAILED, "", null));
     }
     assertEquals(maxFailedAttempts, task.getAttempts().size());
     assertEquals(maxFailedAttempts, task.failedAttempts);
 
-    task.handle(new TaskEventRecoverTask(lastTaskId));
+    task.handle(new TaskEventRecoverTask(task.getTaskId()));
     // it should transit to failed because of the failed task attempt in the
     // last application attempt.
     assertEquals(TaskStateInternal.FAILED, task.getInternalState());
@@ -232,34 +651,34 @@ public class TestTaskRecovery {
    * state and new task attempt is scheduled.
    */
   @Test
-  public void testTaskRecovery3() throws InterruptedException {
-    TezTaskID lastTaskId = getNewTaskID();
-    TezTaskID taskId = getNewTaskID();
+  public void testTaskRecovery_MultipleAttempts3() throws InterruptedException {
     int maxFailedAttempts =
         conf.getInt(TezConfiguration.TEZ_AM_TASK_MAX_FAILED_ATTEMPTS,
             TezConfiguration.TEZ_AM_TASK_MAX_FAILED_ATTEMPTS_DEFAULT);
-    task.restoreFromEvent(new TaskStartedEvent(taskId, "v1", 0, 0));
+    restoreFromTaskStartEvent();
+
     for (int i = 0; i < maxFailedAttempts - 1; ++i) {
-      TezTaskAttemptID attemptId = getNewTaskAttemptID(lastTaskId);
-      task.restoreFromEvent(new TaskAttemptStartedEvent(attemptId, "v1", 0,
-          mockContainerId, mockNodeId, "", ""));
-      task.restoreFromEvent(new TaskAttemptFinishedEvent(attemptId, "v1", 0, 0,
-          TaskAttemptState.FAILED, "", null));
+      TezTaskAttemptID taId = getNewTaskAttemptID(task.getTaskId());
+      task.restoreFromEvent(new TaskAttemptStartedEvent(taId, vertexName, 0L,
+          mock(ContainerId.class), mock(NodeId.class), "", ""));
+      task.restoreFromEvent(new TaskAttemptFinishedEvent(taId, vertexName, 0,
+          0, TaskAttemptState.FAILED, "", null));
     }
     assertEquals(maxFailedAttempts - 1, task.getAttempts().size());
     assertEquals(maxFailedAttempts - 1, task.failedAttempts);
 
-    TezTaskAttemptID newTaskAttemptId = getNewTaskAttemptID(lastTaskId);
+    TezTaskAttemptID newTaskAttemptId = getNewTaskAttemptID(task.getTaskId());
     TaskState recoveredState =
         task.restoreFromEvent(new TaskAttemptStartedEvent(newTaskAttemptId,
-            "v1", 0, mockContainerId, mockNodeId, "", ""));
+            vertexName, 0, mock(ContainerId.class), mock(NodeId.class), "", ""));
+
     assertEquals(TaskState.RUNNING, recoveredState);
     assertEquals(TaskAttemptStateInternal.NEW,
         ((TaskAttemptImpl) task.getAttempt(newTaskAttemptId))
             .getInternalState());
     assertEquals(maxFailedAttempts, task.getAttempts().size());
 
-    task.handle(new TaskEventRecoverTask(lastTaskId));
+    task.handle(new TaskEventRecoverTask(task.getTaskId()));
     // wait until task attempt receive the Recover event from task
     dispatcher.await();
 
@@ -273,12 +692,8 @@ public class TestTaskRecovery {
     assertEquals(maxFailedAttempts + 1, task.getAttempts().size());
   }
 
-  private TezTaskID getNewTaskID() {
-    TezTaskID taskID = TezTaskID.getInstance(vertexId, ++taskCounter);
-    return taskID;
-  }
-
   private TezTaskAttemptID getNewTaskAttemptID(TezTaskID taskId) {
     return TezTaskAttemptID.getInstance(taskId, taskAttemptCounter++);
   }
+
 }

http://git-wip-us.apache.org/repos/asf/tez/blob/f65e65ae/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexRecovery.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexRecovery.java b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexRecovery.java
new file mode 100644
index 0000000..e2f189c
--- /dev/null
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexRecovery.java
@@ -0,0 +1,860 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.dag.app.dag.impl;
+
+import static org.junit.Assert.*;
+import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.security.Credentials;
+import org.apache.hadoop.yarn.api.records.ApplicationId;
+import org.apache.hadoop.yarn.event.DrainDispatcher;
+import org.apache.hadoop.yarn.event.EventHandler;
+import org.apache.hadoop.yarn.util.SystemClock;
+import org.apache.tez.common.counters.TezCounters;
+import org.apache.tez.dag.api.records.DAGProtos;
+import org.apache.tez.dag.api.records.DAGProtos.DAGPlan;
+import org.apache.tez.dag.api.records.DAGProtos.EdgePlan;
+import org.apache.tez.dag.api.records.DAGProtos.PlanEdgeDataMovementType;
+import org.apache.tez.dag.api.records.DAGProtos.PlanEdgeDataSourceType;
+import org.apache.tez.dag.api.records.DAGProtos.PlanEdgeSchedulingType;
+import org.apache.tez.dag.api.records.DAGProtos.PlanTaskConfiguration;
+import org.apache.tez.dag.api.records.DAGProtos.PlanTaskLocationHint;
+import org.apache.tez.dag.api.records.DAGProtos.PlanVertexType;
+import org.apache.tez.dag.api.records.DAGProtos.TezEntityDescriptorProto;
+import org.apache.tez.dag.api.records.DAGProtos.VertexPlan;
+import org.apache.tez.dag.app.AppContext;
+import org.apache.tez.dag.app.TaskAttemptListener;
+import org.apache.tez.dag.app.TaskHeartbeatHandler;
+import org.apache.tez.dag.app.dag.VertexState;
+import org.apache.tez.dag.app.dag.event.DAGEvent;
+import org.apache.tez.dag.app.dag.event.DAGEventType;
+import org.apache.tez.dag.app.dag.event.TaskAttemptEvent;
+import org.apache.tez.dag.app.dag.event.TaskAttemptEventType;
+import org.apache.tez.dag.app.dag.event.TaskEvent;
+import org.apache.tez.dag.app.dag.event.TaskEventRecoverTask;
+import org.apache.tez.dag.app.dag.event.TaskEventType;
+import org.apache.tez.dag.app.dag.event.VertexEvent;
+import org.apache.tez.dag.app.dag.event.VertexEventRecoverVertex;
+import org.apache.tez.dag.app.dag.event.VertexEventType;
+import org.apache.tez.dag.app.dag.impl.TestVertexImpl.CountingOutputCommitter;
+import org.apache.tez.dag.history.events.VertexDataMovementEventsGeneratedEvent;
+import org.apache.tez.dag.history.events.VertexFinishedEvent;
+import org.apache.tez.dag.history.events.VertexInitializedEvent;
+import org.apache.tez.dag.history.events.VertexStartedEvent;
+import org.apache.tez.dag.records.TezDAGID;
+import org.apache.tez.runtime.api.OutputCommitter;
+import org.apache.tez.runtime.api.events.InputDataInformationEvent;
+import org.apache.tez.runtime.api.impl.EventMetaData;
+import org.apache.tez.runtime.api.impl.EventMetaData.EventProducerConsumerType;
+import org.apache.tez.runtime.api.impl.TezEvent;
+import org.junit.Before;
+import org.junit.Test;
+
+import com.google.common.collect.Lists;
+
+public class TestVertexRecovery {
+
+  private static final Log LOG = LogFactory.getLog(TestVertexRecovery.class);
+
+  private DrainDispatcher dispatcher;
+
+  private AppContext mockAppContext;
+  private ApplicationId appId = ApplicationId.newInstance(
+      System.currentTimeMillis(), 1);
+  private DAGImpl dag;
+  private TezDAGID dagId = TezDAGID.getInstance(appId, 1);
+  private String user = "user";
+
+  private long initRequestedTime = 100L;
+  private long initedTime = initRequestedTime + 100L;
+
+  /*
+   * v1 and v2 both feed v3: edge e1 is v1 -> v3, edge e2 is v2 -> v3.
+   */
+  private DAGPlan createDAGPlan() {
+    DAGPlan dag =
+        DAGPlan
+            .newBuilder()
+            .setName("testverteximpl")
+            .addVertex(
+                VertexPlan
+                    .newBuilder()
+                    .setName("vertex1")
+                    .setType(PlanVertexType.NORMAL)
+                    .addTaskLocationHint(
+                        PlanTaskLocationHint.newBuilder().addHost("host1")
+                            .addRack("rack1").build())
+                    .setTaskConfig(
+                        PlanTaskConfiguration.newBuilder().setNumTasks(1)
+                            .setVirtualCores(4).setMemoryMb(1024)
+                            .setJavaOpts("").setTaskModule("x1.y1").build())
+                    .addOutEdgeId("e1")
+                    .addOutputs(
+                        DAGProtos.RootInputLeafOutputProto
+                            .newBuilder()
+                            .setIODescriptor(
+                                TezEntityDescriptorProto.newBuilder()
+                                    .setClassName("output").build())
+                            .setName("outputx")
+                            .setControllerDescriptor(
+                                TezEntityDescriptorProto
+                                    .newBuilder()
+                                    .setClassName(
+                                        CountingOutputCommitter.class.getName())))
+                    .build())
+            .addVertex(
+                VertexPlan
+                    .newBuilder()
+                    .setName("vertex2")
+                    .setType(PlanVertexType.NORMAL)
+                    .addTaskLocationHint(
+                        PlanTaskLocationHint.newBuilder().addHost("host2")
+                            .addRack("rack2").build())
+                    .setTaskConfig(
+                        PlanTaskConfiguration.newBuilder().setNumTasks(2)
+                            .setVirtualCores(4).setMemoryMb(1024)
+                            .setJavaOpts("").setTaskModule("x2.y2").build())
+                    .addOutEdgeId("e2").build())
+            .addVertex(
+                VertexPlan
+                    .newBuilder()
+                    .setName("vertex3")
+                    .setType(PlanVertexType.NORMAL)
+                    .setProcessorDescriptor(
+                        TezEntityDescriptorProto.newBuilder().setClassName(
+                            "x3.y3"))
+                    .addTaskLocationHint(
+                        PlanTaskLocationHint.newBuilder().addHost("host3")
+                            .addRack("rack3").build())
+                    .setTaskConfig(
+                        PlanTaskConfiguration.newBuilder().setNumTasks(2)
+                            .setVirtualCores(4).setMemoryMb(1024)
+                            .setJavaOpts("foo").setTaskModule("x3.y3").build())
+                    .addInEdgeId("e1")
+                    .addInEdgeId("e2")
+                    .addOutputs(
+                        DAGProtos.RootInputLeafOutputProto
+                            .newBuilder()
+                            .setIODescriptor(
+                                TezEntityDescriptorProto.newBuilder()
+                                    .setClassName("output").build())
+                            .setName("outputx")
+                            .setControllerDescriptor(
+                                TezEntityDescriptorProto
+                                    .newBuilder()
+                                    .setClassName(
+                                        CountingOutputCommitter.class.getName())))
+                    .build()
+
+            )
+
+            .addEdge(
+                EdgePlan
+                    .newBuilder()
+                    .setEdgeDestination(
+                        TezEntityDescriptorProto.newBuilder().setClassName(
+                            "i3_v1"))
+                    .setInputVertexName("vertex1")
+                    .setEdgeSource(
+                        TezEntityDescriptorProto.newBuilder()
+                            .setClassName("o1"))
+                    .setOutputVertexName("vertex3")
+                    .setDataMovementType(
+                        PlanEdgeDataMovementType.SCATTER_GATHER).setId("e1")
+                    .setDataSourceType(PlanEdgeDataSourceType.PERSISTED)
+                    .setSchedulingType(PlanEdgeSchedulingType.SEQUENTIAL)
+                    .build())
+            .addEdge(
+                EdgePlan
+                    .newBuilder()
+                    .setEdgeDestination(
+                        TezEntityDescriptorProto.newBuilder().setClassName(
+                            "i3_v2"))
+                    .setInputVertexName("vertex2")
+                    .setEdgeSource(
+                        TezEntityDescriptorProto.newBuilder()
+                            .setClassName("o2"))
+                    .setOutputVertexName("vertex3")
+                    .setDataMovementType(
+                        PlanEdgeDataMovementType.SCATTER_GATHER).setId("e2")
+                    .setDataSourceType(PlanEdgeDataSourceType.PERSISTED)
+                    .setSchedulingType(PlanEdgeSchedulingType.SEQUENTIAL)
+                    .build()).build();
+
+    return dag;
+  }
+
+  class VertexEventHanlder implements EventHandler<VertexEvent> {
+
+    private List<VertexEvent> events = new ArrayList<VertexEvent>();
+
+    @Override
+    public void handle(VertexEvent event) {
+      events.add(event);
+      ((VertexImpl) dag.getVertex(event.getVertexId())).handle(event);
+    }
+
+    public List<VertexEvent> getEvents() {
+      return this.events;
+    }
+  }
+
+  class TaskEventHandler implements EventHandler<TaskEvent> {
+
+    private List<TaskEvent> events = new ArrayList<TaskEvent>();
+
+    @Override
+    public void handle(TaskEvent event) {
+      events.add(event);
+      ((TaskImpl) dag.getVertex(event.getTaskID().getVertexID()).getTask(
+          event.getTaskID())).handle(event);
+    }
+
+    public List<TaskEvent> getEvents() {
+      return events;
+    }
+  }
+
+  class TaskAttemptEventHandler implements EventHandler<TaskAttemptEvent> {
+
+    @Override
+    public void handle(TaskAttemptEvent event) {
+      // TezTaskID taskId = event.getTaskAttemptID().getTaskID();
+      // ((TaskAttemptImpl) vertex1.getTask(taskId).getAttempt(
+      // event.getTaskAttemptID())).handle(event);
+    }
+  }
+
+  private VertexEventHanlder vertexEventHandler;
+  private TaskEventHandler taskEventHandler;
+
+  @Before
+  public void setUp() throws IOException {
+
+    dispatcher = new DrainDispatcher();
+    dispatcher.register(DAGEventType.class, mock(EventHandler.class));
+    vertexEventHandler = new VertexEventHanlder();
+    dispatcher.register(VertexEventType.class, vertexEventHandler);
+    taskEventHandler = new TaskEventHandler();
+    dispatcher.register(TaskEventType.class, taskEventHandler);
+    dispatcher.register(TaskAttemptEventType.class,
+        new TaskAttemptEventHandler());
+    dispatcher.init(new Configuration());
+    dispatcher.start();
+
+    mockAppContext = mock(AppContext.class, RETURNS_DEEP_STUBS);
+
+    DAGPlan dagPlan = createDAGPlan();
+    dag =
+        new DAGImpl(dagId, new Configuration(), dagPlan,
+            dispatcher.getEventHandler(), mock(TaskAttemptListener.class),
+            new Credentials(), new SystemClock(), user,
+            mock(TaskHeartbeatHandler.class), mockAppContext);
+    when(mockAppContext.getCurrentDAG()).thenReturn(dag);
+
+    dag.handle(new DAGEvent(dagId, DAGEventType.DAG_INIT));
+    LOG.info("finish setUp");
+  }
+
+  /**
+   * vertex1(New) -> StartRecoveryTransition(SUCCEEDED)
+   */
+  @Test
+  public void testRecovery_Desired_SUCCEEDED() {
+    VertexImpl vertex1 = (VertexImpl) dag.getVertex("vertex1");
+    VertexState recoveredState = vertex1.restoreFromEvent(new VertexInitializedEvent(vertex1.getVertexId(),
+        "vertex1", initRequestedTime, initedTime, vertex1.getTotalTasks(), "", null));
+    assertEquals(VertexState.INITED, recoveredState);
+    
+    vertex1.handle(new VertexEventRecoverVertex(vertex1.getVertexId(),
+        VertexState.SUCCEEDED));
+    dispatcher.await();
+    assertEquals(VertexState.SUCCEEDED, vertex1.getState());
+    assertEquals(vertex1.numTasks, vertex1.succeededTaskCount);
+    assertEquals(vertex1.numTasks, vertex1.completedTaskCount);
+    // recover its task
+    assertTaskRecoveredEventSent(vertex1);
+
+    // vertex3 is still in NEW: when the desired state is a completed
+    // state, each vertex recovers by itself and does not depend on its parent
+    VertexImpl vertex3 = (VertexImpl) dag.getVertex("vertex3");
+    assertEquals(VertexState.NEW, vertex3.getState());
+    // no VertexEvent is passed to the downstream vertex
+    assertEquals(0, vertexEventHandler.getEvents().size());
+
+  }
+
+  /**
+   * vertex1(New) -> StartRecoveryTransition(FAILED)
+   */
+  @Test
+  public void testRecovery_Desired_FAILED() {
+    VertexImpl vertex1 = (VertexImpl) dag.getVertex("vertex1");
+    VertexState recoveredState = vertex1.restoreFromEvent(new VertexInitializedEvent(vertex1.getVertexId(),
+        "vertex1", initRequestedTime, initedTime, vertex1.getTotalTasks(), "", null));
+    assertEquals(VertexState.INITED, recoveredState);
+    
+    vertex1.handle(new VertexEventRecoverVertex(vertex1.getVertexId(),
+        VertexState.FAILED));
+    dispatcher.await();
+    assertEquals(VertexState.FAILED, vertex1.getState());
+    assertEquals(vertex1.numTasks, vertex1.failedTaskCount);
+    assertEquals(0, vertex1.completedTaskCount);
+    // recover its task
+    assertTaskRecoveredEventSent(vertex1);
+
+    // vertex3 is still in NEW: when the desired state is a completed
+    // state, each vertex recovers by itself and does not depend on its parent
+    VertexImpl vertex3 = (VertexImpl) dag.getVertex("vertex3");
+    assertEquals(VertexState.NEW, vertex3.getState());
+    // no VertexEvent is passed to the downstream vertex
+    assertEquals(0, vertexEventHandler.getEvents().size());
+  }
+
+  /**
+   * vertex1(New) -> StartRecoveryTransition(KILLED)
+   */
+  @Test
+  public void testRecovery_Desired_KILLED() {
+    VertexImpl vertex1 = (VertexImpl) dag.getVertex("vertex1");
+    VertexState recoveredState = vertex1.restoreFromEvent(new VertexInitializedEvent(vertex1.getVertexId(),
+        "vertex1", initRequestedTime, initedTime, vertex1.getTotalTasks(), "", null));
+    assertEquals(VertexState.INITED, recoveredState);
+    
+    vertex1.handle(new VertexEventRecoverVertex(vertex1.getVertexId(),
+        VertexState.KILLED));
+    dispatcher.await();
+    assertEquals(VertexState.KILLED, vertex1.getState());
+    assertEquals(vertex1.numTasks, vertex1.killedTaskCount);
+    assertEquals(0, vertex1.completedTaskCount);
+    // recover its task
+    assertTaskRecoveredEventSent(vertex1);
+
+    // vertex3 is still in NEW: when the desired state is a completed
+    // state, each vertex recovers by itself and does not depend on its parent
+    VertexImpl vertex3 = (VertexImpl) dag.getVertex("vertex3");
+    assertEquals(VertexState.NEW, vertex3.getState());
+    // no VertexEvent is passed to the downstream vertex
+    assertEquals(0, vertexEventHandler.getEvents().size());
+  }
+
+  /**
+   * vertex1(New) -> StartRecoveryTransition(ERROR)
+   */
+  @Test
+  public void testRecovery_Desired_ERROR() {
+    VertexImpl vertex1 = (VertexImpl) dag.getVertex("vertex1");
+    VertexState recoveredState = vertex1.restoreFromEvent(new VertexInitializedEvent(vertex1.getVertexId(),
+        "vertex1", initRequestedTime, initedTime, vertex1.getTotalTasks(), "", null));
+    assertEquals(VertexState.INITED, recoveredState);
+    
+    vertex1.handle(new VertexEventRecoverVertex(vertex1.getVertexId(),
+        VertexState.ERROR));
+    dispatcher.await();
+    assertEquals(VertexState.ERROR, vertex1.getState());
+    assertEquals(vertex1.numTasks, vertex1.failedTaskCount);
+    assertEquals(0, vertex1.completedTaskCount);
+    // recover its task
+    assertTaskRecoveredEventSent(vertex1);
+
+    // vertex3 is still in NEW: when the desired state is a completed
+    // state, each vertex recovers by itself and does not depend on its parent
+    VertexImpl vertex3 = (VertexImpl) dag.getVertex("vertex3");
+    assertEquals(VertexState.NEW, vertex3.getState());
+    // no VertexEvent is passed to the downstream vertex
+    assertEquals(0, vertexEventHandler.getEvents().size());
+  }
+
+  private TezEvent createTezEvent() {
+    return new TezEvent(InputDataInformationEvent.createWithSerializedPayload(0, ByteBuffer.allocate(0)),
+        new EventMetaData(EventProducerConsumerType.INPUT, "vertex1", null,
+            null));
+  }
+
+  /**
+   * vertex1(New) -> restoreFromDataMovementEvent -> StartRecoveryTransition
+   */
+  @Test
+  public void testRecovery_New_Desired_RUNNING() {
+    VertexImpl vertex1 = (VertexImpl) dag.getVertex("vertex1");
+    VertexState recoveredState =
+        vertex1.restoreFromEvent(new VertexDataMovementEventsGeneratedEvent(
+            vertex1.getVertexId(), Lists.newArrayList(createTezEvent())));
+    assertEquals(VertexState.NEW, recoveredState);
+    assertEquals(1, vertex1.recoveredEvents.size());
+
+    vertex1.handle(new VertexEventRecoverVertex(vertex1.getVertexId(),
+        VertexState.RUNNING));
+    dispatcher.await();
+
+    // InputDataInformationEvent is removed
+    assertEquals(0, vertex1.recoveredEvents.size());
+    // V_INIT and V_START are sent
+    assertEquals(VertexState.RUNNING, vertex1.getState());
+
+    // verify OutputCommitter is initialized
+    assertOutputCommitters(vertex1);
+
+    VertexImpl vertex3 = (VertexImpl) dag.getVertex("vertex3");
+    // wait for recovery of vertex2
+    assertEquals(VertexState.RECOVERING, vertex3.getState());
+    assertEquals(1, vertex3.numRecoveredSourceVertices);
+    assertEquals(1, vertex3.numInitedSourceVertices);
+    assertEquals(1, vertex3.numStartedSourceVertices);
+    assertEquals(1, vertex3.getDistanceFromRoot());
+
+  }
+
+  private void assertTaskRecoveredEventSent(VertexImpl vertex) {
+    int sentNum = 0;
+    for (TaskEvent event : taskEventHandler.getEvents()) {
+      if (event.getType() == TaskEventType.T_RECOVER) {
+        TaskEventRecoverTask recoverEvent = (TaskEventRecoverTask)event;
+        if (recoverEvent.getTaskID().getVertexID().equals(vertex.getVertexId())){
+          sentNum++;
+        }
+      }
+    }
+    assertEquals("expect " + vertex.getTotalTasks()
+        + " TaskEventTaskRecover sent for vertex:" + vertex.getVertexId() +
+        "but actuall sent " + sentNum, vertex.getTotalTasks(), sentNum);
+  }
+
+  private void assertOutputCommitters(VertexImpl vertex){
+    assertTrue(vertex.getOutputCommitters() != null);
+    for (OutputCommitter c : vertex.getOutputCommitters().values()) {
+      CountingOutputCommitter committer = (CountingOutputCommitter) c;
+      assertEquals(0, committer.abortCounter);
+      assertEquals(0, committer.commitCounter);
+      assertEquals(1, committer.initCounter);
+      assertEquals(1, committer.setupCounter);
+    }
+  }
+
+  private void restoreFromInitializedEvent(VertexImpl vertex) {
+    long initTimeRequested = 100L;
+    long initedTime = initTimeRequested + 100L;
+    VertexState recoveredState =
+        vertex.restoreFromEvent(new VertexInitializedEvent(vertex
+            .getVertexId(), "vertex1", initTimeRequested, initedTime, vertex.getTotalTasks(),
+            "", null));
+    assertEquals(VertexState.INITED, recoveredState);
+    assertEquals(vertex.getTotalTasks(), vertex.getTasks().size());
+    assertEquals(initTimeRequested, vertex.initTimeRequested);
+    assertEquals(initedTime, vertex.initedTime);
+  }
+
+  /**
+   * restoreFromVertexInitializedEvent -> StartRecoveryTransition
+   */
+  @Test
+  public void testRecovery_Inited_Desired_RUNNING() {
+    VertexImpl vertex1 = (VertexImpl) dag.getVertex("vertex1");
+    restoreFromInitializedEvent(vertex1);
+
+    VertexState recoveredState =
+        vertex1.restoreFromEvent(new VertexDataMovementEventsGeneratedEvent(
+            vertex1.getVertexId(), Lists.newArrayList(createTezEvent())));
+    assertEquals(VertexState.INITED, recoveredState);
+
+    vertex1.handle(new VertexEventRecoverVertex(vertex1.getVertexId(),
+        VertexState.RUNNING));
+    dispatcher.await();
+
+    // InputDataInformationEvent is removed
+    assertEquals(0, vertex1.recoveredEvents.size());
+    assertEquals(VertexState.RUNNING, vertex1.getState());
+    // task recovered event is sent
+    assertTaskRecoveredEventSent(vertex1);
+    // verify OutputCommitter is initialized
+    assertOutputCommitters(vertex1);
+
+    VertexImpl vertex3 = (VertexImpl) dag.getVertex("vertex3");
+    // wait for recovery of vertex2
+    assertEquals(VertexState.RECOVERING, vertex3.getState());
+    assertEquals(1, vertex3.numRecoveredSourceVertices);
+    assertEquals(1, vertex3.numInitedSourceVertices);
+    assertEquals(1, vertex3.numStartedSourceVertices);
+    assertEquals(1, vertex3.getDistanceFromRoot());
+  }
+
+  /**
+   * restoreFromVertexInitializedEvent -> restoreFromVertexStartedEvent ->
+   * StartRecoveryTransition
+   */
+  @Test
+  public void testRecovery_Started_Desired_RUNNING() {
+    VertexImpl vertex1 = (VertexImpl) dag.getVertex("vertex1");
+    restoreFromInitializedEvent(vertex1);
+
+    long startTimeRequested = initedTime + 100L;
+    long startedTime = startTimeRequested + 100L;
+    VertexState recoveredState =
+        vertex1.restoreFromEvent(new VertexStartedEvent(vertex1.getVertexId(),
+            startTimeRequested, startedTime));
+    assertEquals(VertexState.RUNNING, recoveredState);
+    assertEquals(startTimeRequested, vertex1.startTimeRequested);
+    assertEquals(startedTime, vertex1.startedTime);
+
+    recoveredState =
+        vertex1.restoreFromEvent(new VertexDataMovementEventsGeneratedEvent(
+            vertex1.getVertexId(), Lists.newArrayList(createTezEvent())));
+    assertEquals(VertexState.RUNNING, recoveredState);
+    assertEquals(1, vertex1.recoveredEvents.size());
+
+    vertex1.handle(new VertexEventRecoverVertex(vertex1.getVertexId(),
+        VertexState.RUNNING));
+    dispatcher.await();
+
+    // InputDataInformationEvent is removed
+    assertEquals(0, vertex1.recoveredEvents.size());
+    assertEquals(VertexState.RUNNING, vertex1.getState());
+    // task recovered event is sent
+    assertTaskRecoveredEventSent(vertex1);
+    // verify OutputCommitter is initialized
+    assertOutputCommitters(vertex1);
+
+    VertexImpl vertex3 = (VertexImpl) dag.getVertex("vertex3");
+    // wait for recovery of vertex2
+    assertEquals(VertexState.RECOVERING, vertex3.getState());
+    assertEquals(1, vertex3.numRecoveredSourceVertices);
+    assertEquals(1, vertex3.numInitedSourceVertices);
+    assertEquals(1, vertex3.numStartedSourceVertices);
+    assertEquals(1, vertex3.getDistanceFromRoot());
+  }
+
+  /**
+   * restoreFromVertexInitializedEvent -> restoreFromVertexStartedEvent ->
+   * restoreFromVertexFinishedEvent -> StartRecoveryTransition
+   */
+  @Test
+  public void testRecovery_Finished_Desired_RUNNING() {
+    // v1: initFromInitializedEvent
+    VertexImpl vertex1 = (VertexImpl) dag.getVertex("vertex1");
+    restoreFromInitializedEvent(vertex1);
+
+    // v1: initFromStartedEvent
+    long startRequestedTime = initedTime + 100L;
+    long startTime = startRequestedTime + 100L;
+    VertexState recoveredState =
+        vertex1.restoreFromEvent(new VertexStartedEvent(vertex1.getVertexId(),
+            startRequestedTime, startTime));
+    assertEquals(VertexState.RUNNING, recoveredState);
+
+    // v1: initFromFinishedEvent
+    long finishTime = startTime + 100L;
+    recoveredState =
+        vertex1.restoreFromEvent(new VertexFinishedEvent(vertex1.getVertexId(),
+            "vertex1", initRequestedTime, initedTime, startRequestedTime,
+            startTime, finishTime, VertexState.SUCCEEDED, "",
+            new TezCounters(), new VertexStats()));
+    assertEquals(finishTime, vertex1.finishTime);
+    assertEquals(VertexState.SUCCEEDED, recoveredState);
+    assertEquals(false, vertex1.recoveryCommitInProgress);
+
+    vertex1.handle(new VertexEventRecoverVertex(vertex1.getVertexId(),
+        VertexState.RUNNING));
+    dispatcher.await();
+
+    // InputDataInformationEvent is removed
+    assertEquals(0, vertex1.recoveredEvents.size());
+    assertEquals(VertexState.RUNNING, vertex1.getState());
+    // task recovered event is sent
+    assertTaskRecoveredEventSent(vertex1);
+
+    // verify OutputCommitter is initialized
+    assertOutputCommitters(vertex1);
+
+    VertexImpl vertex3 = (VertexImpl) dag.getVertex("vertex3");
+    // wait for recovery of vertex2
+    assertEquals(VertexState.RECOVERING, vertex3.getState());
+    assertEquals(1, vertex3.numRecoveredSourceVertices);
+    assertEquals(1, vertex3.numInitedSourceVertices);
+    assertEquals(1, vertex3.numStartedSourceVertices);
+    assertEquals(1, vertex3.getDistanceFromRoot());
+  }
+
+  /**
+   * vertex1 (New) -> StartRecoveryTransition <br>
+   * vertex2 (New) -> StartRecoveryTransition <br> vertex3 (New) -> RecoverTransition
+   */
+  @Test
+  public void testRecovery_RecoveringFromNew() {
+    VertexImpl vertex1 = (VertexImpl) dag.getVertex("vertex1");
+    vertex1.handle(new VertexEventRecoverVertex(vertex1.getVertexId(),
+        VertexState.RUNNING));
+    dispatcher.await();
+    assertEquals(VertexState.RUNNING, vertex1.getState());
+    assertEquals(1, vertex1.getTasks().size());
+    // verify OutputCommitter is initialized
+    assertOutputCommitters(vertex1);
+
+    VertexImpl vertex3 = (VertexImpl) dag.getVertex("vertex3");
+    VertexState recoveredState =
+        vertex3.restoreFromEvent(new VertexDataMovementEventsGeneratedEvent(
+            vertex3.getVertexId(), Lists.newArrayList(createTezEvent())));
+    assertEquals(VertexState.NEW, recoveredState);
+    assertEquals(1, vertex3.recoveredEvents.size());
+
+    // wait for recovery of vertex2
+    assertEquals(VertexState.RECOVERING, vertex3.getState());
+    assertEquals(1, vertex3.numRecoveredSourceVertices);
+    assertEquals(1, vertex3.numInitedSourceVertices);
+    assertEquals(1, vertex3.numStartedSourceVertices);
+    assertEquals(1, vertex3.getDistanceFromRoot());
+
+    VertexImpl vertex2 = (VertexImpl) dag.getVertex("vertex2");
+    vertex2.handle(new VertexEventRecoverVertex(vertex2.getVertexId(),
+        VertexState.RUNNING));
+    dispatcher.await();
+    assertEquals(VertexState.RUNNING, vertex2.getState());
+    // no OutputCommitter for vertex2
+    assertNull(vertex2.getOutputCommitters());
+
+    // v3 goes to RUNNING because v1 and v2 have both started
+    assertEquals(VertexState.RUNNING, vertex3.getState());
+    assertEquals(2, vertex3.numRecoveredSourceVertices);
+    assertEquals(2, vertex3.numInitedSourceVertices);
+    assertEquals(2, vertex3.numStartedSourceVertices);
+    assertEquals(1, vertex3.getDistanceFromRoot());
+    // RootInputDataInformation is removed
+    assertEquals(0, vertex3.recoveredEvents.size());
+
+    // verify OutputCommitter is initialized
+    assertOutputCommitters(vertex3);
+
+  }
+
+  /**
+   * vertex1 (New) -> restoreFromInitialized -> StartRecoveryTransition<br>
+   * vertex2 (New) -> restoreFromInitialized -> StartRecoveryTransition<br>
+   * vertex3 (New) -> restoreFromVertexInitedEvent -> RecoverTransition<br>
+   */
+  @Test
+  public void testRecovery_RecoveringFromInited() {
+    VertexImpl vertex1 = (VertexImpl) dag.getVertex("vertex1");
+    restoreFromInitializedEvent(vertex1);
+    vertex1.handle(new VertexEventRecoverVertex(vertex1.getVertexId(),
+        VertexState.RUNNING));
+    dispatcher.await();
+    assertEquals(VertexState.RUNNING, vertex1.getState());
+    assertEquals(vertex1.getTotalTasks(), vertex1.getTasks().size());
+    // verify OutputCommitter is initialized
+    assertOutputCommitters(vertex1);
+
+    VertexImpl vertex3 = (VertexImpl) dag.getVertex("vertex3");
+    VertexState recoveredState =
+        vertex3.restoreFromEvent(new VertexDataMovementEventsGeneratedEvent(
+            vertex3.getVertexId(), Lists.newArrayList(createTezEvent())));
+    assertEquals(VertexState.NEW, recoveredState);
+    assertEquals(1, vertex3.recoveredEvents.size());
+    recoveredState =
+        vertex3.restoreFromEvent(new VertexInitializedEvent(vertex3
+            .getVertexId(), "vertex3", initRequestedTime, initedTime, 2, "",
+            null));
+    assertEquals(VertexState.INITED, recoveredState);
+    // wait for recovery of vertex2
+    assertEquals(VertexState.RECOVERING, vertex3.getState());
+    assertEquals(1, vertex3.numRecoveredSourceVertices);
+    assertEquals(1, vertex3.numInitedSourceVertices);
+    assertEquals(1, vertex3.numStartedSourceVertices);
+    assertEquals(1, vertex3.getDistanceFromRoot());
+
+    VertexImpl vertex2 = (VertexImpl) dag.getVertex("vertex2");
+    restoreFromInitializedEvent(vertex2);
+    vertex2.handle(new VertexEventRecoverVertex(vertex2.getVertexId(),
+        VertexState.RUNNING));
+    dispatcher.await();
+    assertEquals(VertexState.RUNNING, vertex2.getState());
+
+    // v3 goes to RUNNING because v1 and v2 have both started
+    assertEquals(VertexState.RUNNING, vertex3.getState());
+    assertEquals(2, vertex3.numRecoveredSourceVertices);
+    // numInitedSourceVertices is wrong but doesn't matter because v3 has
+    // already initialized
+    assertEquals(2, vertex3.numInitedSourceVertices);
+    assertEquals(2, vertex3.numStartedSourceVertices);
+    assertEquals(1, vertex3.getDistanceFromRoot());
+    // RootInputDataInformation is removed
+    assertEquals(0, vertex3.recoveredEvents.size());
+    // verify OutputCommitter is initialized
+    assertOutputCommitters(vertex3);
+    // T_RECOVER events expected: 1 for vertex1, 2 for vertex2 and 2 for vertex3
+    assertTaskRecoveredEventSent(vertex1);
+    assertTaskRecoveredEventSent(vertex2);
+    assertTaskRecoveredEventSent(vertex3);
+  }
+
+  /**
+   * Recovery scenario in which vertex1, vertex2 and vertex3 each have an
+   * initialized and a started event replayed before recovery proceeds.
+   * <p>
+   * vertex1 (New) -> restoreFromInitialized -> restoreFromVertexStarted ->
+   * StartRecoveryTransition <br>
+   * vertex2 (New) -> restoreFromInitialized -> restoreFromVertexStarted ->
+   * StartRecoveryTransition <br>
+   * vertex3 (New) -> restoreFromInitialized -> restoreFromVertexStarted ->
+   * RecoverTransition
+   * <p>
+   * Only vertex1 and vertex2 are handed a {@link VertexEventRecoverVertex} in
+   * this test. vertex3 is expected to sit in RECOVERING after vertex1 alone has
+   * been recovered, and to reach RUNNING once vertex2 has been recovered too.
+   */
+  @Test
+  public void testRecovery_RecoveringFromRunning() {
+    // ---- vertex1: replay initialized + started, then trigger recovery ----
+    final VertexImpl v1 = (VertexImpl) dag.getVertex("vertex1");
+    restoreFromInitializedEvent(v1);
+    final VertexStartedEvent v1Started = new VertexStartedEvent(
+        v1.getVertexId(), initRequestedTime + 100L, initRequestedTime + 200L);
+    final VertexState v1Restored = v1.restoreFromEvent(v1Started);
+    assertEquals(VertexState.RUNNING, v1Restored);
+
+    final VertexEventRecoverVertex v1Recover =
+        new VertexEventRecoverVertex(v1.getVertexId(), VertexState.RUNNING);
+    v1.handle(v1Recover);
+    dispatcher.await();
+    assertEquals(VertexState.RUNNING, v1.getState());
+    assertEquals(1, v1.getTasks().size());
+    // the OutputCommitter of vertex1 must have been set up by now
+    assertOutputCommitters(v1);
+
+    // ---- vertex3: replay data movement, initialized and started events ----
+    final VertexImpl v3 = (VertexImpl) dag.getVertex("vertex3");
+    final VertexDataMovementEventsGeneratedEvent v3DataMovement =
+        new VertexDataMovementEventsGeneratedEvent(
+            v3.getVertexId(), Lists.newArrayList(createTezEvent()));
+    final VertexState v3AfterDataMovement = v3.restoreFromEvent(v3DataMovement);
+    // a data movement event alone leaves the restored state at NEW ...
+    assertEquals(VertexState.NEW, v3AfterDataMovement);
+    // ... but the event itself is kept for later
+    assertEquals(1, v3.recoveredEvents.size());
+
+    final VertexInitializedEvent v3Inited = new VertexInitializedEvent(
+        v3.getVertexId(), "vertex3", initRequestedTime, initedTime,
+        v3.getTotalTasks(), "", null);
+    final VertexState v3AfterInit = v3.restoreFromEvent(v3Inited);
+    assertEquals(VertexState.INITED, v3AfterInit);
+
+    final VertexStartedEvent v3Started = new VertexStartedEvent(
+        v3.getVertexId(), initRequestedTime + 100L, initRequestedTime + 200L);
+    final VertexState v3AfterStart = v3.restoreFromEvent(v3Started);
+    assertEquals(VertexState.RUNNING, v3AfterStart);
+
+    // vertex2 has not been recovered yet, so vertex3 is still RECOVERING and
+    // has only accounted for one source vertex
+    assertEquals(VertexState.RECOVERING, v3.getState());
+    assertEquals(1, v3.numRecoveredSourceVertices);
+    assertEquals(1, v3.numInitedSourceVertices);
+    assertEquals(1, v3.numStartedSourceVertices);
+    assertEquals(1, v3.getDistanceFromRoot());
+
+    // ---- vertex2: replay initialized + started, then trigger recovery ----
+    final VertexImpl v2 = (VertexImpl) dag.getVertex("vertex2");
+    final VertexInitializedEvent v2Inited = new VertexInitializedEvent(
+        v2.getVertexId(), "vertex2", initRequestedTime, initedTime,
+        v2.getTotalTasks(), "", null);
+    final VertexState v2AfterInit = v2.restoreFromEvent(v2Inited);
+    assertEquals(VertexState.INITED, v2AfterInit);
+
+    final VertexStartedEvent v2Started = new VertexStartedEvent(
+        v2.getVertexId(), initRequestedTime + 100L, initRequestedTime + 200L);
+    final VertexState v2AfterStart = v2.restoreFromEvent(v2Started);
+    assertEquals(VertexState.RUNNING, v2AfterStart);
+
+    final VertexEventRecoverVertex v2Recover =
+        new VertexEventRecoverVertex(v2.getVertexId(), VertexState.RUNNING);
+    v2.handle(v2Recover);
+    dispatcher.await();
+    assertEquals(VertexState.RUNNING, v2.getState());
+
+    // with vertex1 and vertex2 both started, vertex3 moves on to RUNNING
+    assertEquals(VertexState.RUNNING, v3.getState());
+    assertEquals(2, v3.numRecoveredSourceVertices);
+    assertEquals(2, v3.numInitedSourceVertices);
+    assertEquals(2, v3.numStartedSourceVertices);
+    assertEquals(1, v3.getDistanceFromRoot());
+    // the data movement event replayed earlier is no longer held
+    assertEquals(0, v3.recoveredEvents.size());
+    // the OutputCommitter of vertex3 must have been set up by now
+    assertOutputCommitters(v3);
+
+    assertTaskRecoveredEventSent(v1);
+    assertTaskRecoveredEventSent(v2);
+    assertTaskRecoveredEventSent(v3);
+  }
+
+  /**
+   * Recovery scenario in which vertex1 additionally has a finished event
+   * (final state SUCCEEDED) replayed before it is asked to recover.
+   * <p>
+   * vertex1 (New) -> restoreFromInitialized -> restoreFromVertexStarted ->
+   * restoreFromVertexFinished -> StartRecoveryTransition<br>
+   * vertex2 (New) -> restoreFromInitialized -> restoreFromVertexStarted ->
+   * StartRecoveryTransition<br>
+   * vertex3 (New) -> restoreFromInitialized -> restoreFromVertexStarted ->
+   * RecoverTransition
+   * <p>
+   * NOTE(review): the previous description of this test listed a
+   * restoreFromVertexFinished step for vertex2 as well, but the body only
+   * replays a finished event for vertex1 - confirm which one is intended.
+   * <p>
+   * Only vertex1 and vertex2 are handed a {@link VertexEventRecoverVertex} in
+   * this test. vertex3 is expected to sit in RECOVERING after vertex1 alone has
+   * been recovered, and to reach RUNNING once vertex2 has been recovered too.
+   */
+  @Test
+  public void testRecovery_RecoveringFromSUCCEEDED() {
+    // ---- vertex1: replay initialized + started + finished, then recover ----
+    final VertexImpl v1 = (VertexImpl) dag.getVertex("vertex1");
+    restoreFromInitializedEvent(v1);
+    final VertexStartedEvent v1Started = new VertexStartedEvent(
+        v1.getVertexId(), initRequestedTime + 100L, initRequestedTime + 200L);
+    final VertexState v1AfterStart = v1.restoreFromEvent(v1Started);
+    assertEquals(VertexState.RUNNING, v1AfterStart);
+
+    final VertexFinishedEvent v1Finished = new VertexFinishedEvent(
+        v1.getVertexId(), "vertex1", initRequestedTime, initedTime,
+        initRequestedTime + 300L, initRequestedTime + 400L,
+        initRequestedTime + 500L, VertexState.SUCCEEDED, "",
+        new TezCounters(), new VertexStats());
+    final VertexState v1AfterFinish = v1.restoreFromEvent(v1Finished);
+    assertEquals(VertexState.SUCCEEDED, v1AfterFinish);
+
+    final VertexEventRecoverVertex v1Recover =
+        new VertexEventRecoverVertex(v1.getVertexId(), VertexState.RUNNING);
+    v1.handle(v1Recover);
+    dispatcher.await();
+    assertEquals(VertexState.RUNNING, v1.getState());
+    assertEquals(1, v1.getTasks().size());
+    // the OutputCommitter of vertex1 must have been set up by now
+    assertOutputCommitters(v1);
+
+    // ---- vertex3: replay data movement, initialized and started events ----
+    final VertexImpl v3 = (VertexImpl) dag.getVertex("vertex3");
+    final VertexDataMovementEventsGeneratedEvent v3DataMovement =
+        new VertexDataMovementEventsGeneratedEvent(
+            v3.getVertexId(), Lists.newArrayList(createTezEvent()));
+    final VertexState v3AfterDataMovement = v3.restoreFromEvent(v3DataMovement);
+    // a data movement event alone leaves the restored state at NEW ...
+    assertEquals(VertexState.NEW, v3AfterDataMovement);
+    // ... but the event itself is kept for later
+    assertEquals(1, v3.recoveredEvents.size());
+    restoreFromInitializedEvent(v3);
+
+    final VertexStartedEvent v3Started = new VertexStartedEvent(
+        v3.getVertexId(), initRequestedTime + 100L, initRequestedTime + 200L);
+    final VertexState v3AfterStart = v3.restoreFromEvent(v3Started);
+    assertEquals(VertexState.RUNNING, v3AfterStart);
+
+    // vertex2 has not been recovered yet, so vertex3 is still RECOVERING and
+    // has only accounted for one source vertex
+    assertEquals(VertexState.RECOVERING, v3.getState());
+    assertEquals(1, v3.numRecoveredSourceVertices);
+    assertEquals(1, v3.numInitedSourceVertices);
+    assertEquals(1, v3.numStartedSourceVertices);
+    assertEquals(1, v3.getDistanceFromRoot());
+
+    // ---- vertex2: replay initialized + started, then trigger recovery ----
+    final VertexImpl v2 = (VertexImpl) dag.getVertex("vertex2");
+    final VertexInitializedEvent v2Inited = new VertexInitializedEvent(
+        v2.getVertexId(), "vertex2", initRequestedTime, initedTime,
+        v2.getTotalTasks(), "", null);
+    final VertexState v2AfterInit = v2.restoreFromEvent(v2Inited);
+    assertEquals(VertexState.INITED, v2AfterInit);
+
+    final VertexStartedEvent v2Started = new VertexStartedEvent(
+        v2.getVertexId(), initRequestedTime + 100L, initRequestedTime + 200L);
+    final VertexState v2AfterStart = v2.restoreFromEvent(v2Started);
+    assertEquals(VertexState.RUNNING, v2AfterStart);
+
+    final VertexEventRecoverVertex v2Recover =
+        new VertexEventRecoverVertex(v2.getVertexId(), VertexState.RUNNING);
+    v2.handle(v2Recover);
+    dispatcher.await();
+    assertEquals(VertexState.RUNNING, v2.getState());
+
+    // with vertex1 and vertex2 both started, vertex3 moves on to RUNNING
+    assertEquals(VertexState.RUNNING, v3.getState());
+    assertEquals(2, v3.numRecoveredSourceVertices);
+    assertEquals(2, v3.numInitedSourceVertices);
+    assertEquals(2, v3.numStartedSourceVertices);
+    assertEquals(1, v3.getDistanceFromRoot());
+    // the data movement event replayed earlier is no longer held
+    assertEquals(0, v3.recoveredEvents.size());
+    // the OutputCommitter of vertex3 must have been set up by now
+    assertOutputCommitters(v3);
+
+    assertTaskRecoveredEventSent(v1);
+    assertTaskRecoveredEventSent(v2);
+    assertTaskRecoveredEventSent(v3);
+  }
+}


Mime
View raw message