tez-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From ss...@apache.org
Subject [1/7] tez git commit: TEZ-2708. Rename classes and variables post TEZ-2003 changes. (sseth)
Date Tue, 25 Aug 2015 23:48:48 GMT
Repository: tez
Updated Branches:
  refs/heads/master dc0ee0115 -> 8b278ea84


http://git-wip-us.apache.org/repos/asf/tez/blob/8b278ea8/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestTaskSchedulerManager.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestTaskSchedulerManager.java b/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestTaskSchedulerManager.java
new file mode 100644
index 0000000..98b7baa
--- /dev/null
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestTaskSchedulerManager.java
@@ -0,0 +1,742 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.dag.app.rm;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.eq;
+import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.nio.ByteBuffer;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.security.Credentials;
+import org.apache.hadoop.yarn.api.records.ApplicationId;
+import org.apache.hadoop.yarn.api.records.Container;
+import org.apache.hadoop.yarn.api.records.ContainerExitStatus;
+import org.apache.hadoop.yarn.api.records.ContainerId;
+import org.apache.hadoop.yarn.api.records.ContainerStatus;
+import org.apache.hadoop.yarn.api.records.LocalResource;
+import org.apache.hadoop.yarn.api.records.NodeId;
+import org.apache.hadoop.yarn.api.records.Priority;
+import org.apache.hadoop.yarn.api.records.Resource;
+import org.apache.hadoop.yarn.event.Event;
+import org.apache.hadoop.yarn.event.EventHandler;
+import org.apache.tez.common.ContainerSignatureMatcher;
+import org.apache.tez.common.TezUtils;
+import org.apache.tez.dag.api.NamedEntityDescriptor;
+import org.apache.tez.dag.api.TaskLocationHint;
+import org.apache.tez.dag.api.TezConfiguration;
+import org.apache.tez.dag.api.TezConstants;
+import org.apache.tez.dag.api.UserPayload;
+import org.apache.tez.dag.api.client.DAGClientServer;
+import org.apache.tez.dag.app.AppContext;
+import org.apache.tez.dag.app.ContainerContext;
+import org.apache.tez.dag.app.ServicePluginLifecycleAbstractService;
+import org.apache.tez.dag.app.dag.TaskAttempt;
+import org.apache.tez.dag.app.dag.impl.TaskAttemptImpl;
+import org.apache.tez.dag.app.dag.impl.TaskImpl;
+import org.apache.tez.dag.app.dag.impl.VertexImpl;
+import org.apache.tez.dag.app.rm.container.AMContainer;
+import org.apache.tez.dag.app.rm.container.AMContainerEventAssignTA;
+import org.apache.tez.dag.app.rm.container.AMContainerEventCompleted;
+import org.apache.tez.dag.app.rm.container.AMContainerEventType;
+import org.apache.tez.dag.app.rm.container.AMContainerMap;
+import org.apache.tez.dag.app.rm.container.AMContainerState;
+import org.apache.tez.dag.app.web.WebUIService;
+import org.apache.tez.dag.records.TaskAttemptTerminationCause;
+import org.apache.tez.dag.records.TezDAGID;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.dag.records.TezTaskID;
+import org.apache.tez.dag.records.TezVertexID;
+import org.apache.tez.runtime.api.impl.TaskSpec;
+import org.apache.tez.serviceplugins.api.TaskAttemptEndReason;
+import org.apache.tez.serviceplugins.api.TaskScheduler;
+import org.apache.tez.serviceplugins.api.TaskSchedulerContext;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import com.google.common.collect.Lists;
+
+/**
+ * Unit tests for {@link TaskSchedulerManager}: routing of scheduling events to the configured task
+ * schedulers, translation of container completion and preemption statuses into
+ * {@link AMContainerEventCompleted} events, and construction of the Tez UI history URL.
+ */
+@SuppressWarnings("rawtypes")
+public class TestTaskSchedulerManager {
+
+  /** Event handler that records every dispatched event, in arrival order, for later assertions. */
+  class TestEventHandler implements EventHandler {
+    List<Event> events = Lists.newLinkedList();
+    @Override
+    public void handle(Event event) {
+      events.add(event);
+    }
+  }
+
+  /**
+   * {@link TaskSchedulerManager} backed by the shared {@code mockTaskScheduler}. It sets
+   * {@code notify} and wakes waiters whenever the base class calls {@code notifyForTest()}.
+   */
+  class MockTaskSchedulerManager extends TaskSchedulerManager {
+
+    // Set to true by notifyForTest(); tests reset it to false before waiting on it.
+    final AtomicBoolean notify = new AtomicBoolean(false);
+
+    public MockTaskSchedulerManager(AppContext appContext,
+                                    DAGClientServer clientService, EventHandler eventHandler,
+                                    ContainerSignatureMatcher containerSignatureMatcher,
+                                    WebUIService webUI) {
+      super(appContext, clientService, eventHandler, containerSignatureMatcher, webUI,
+          Lists.newArrayList(new NamedEntityDescriptor("FakeDescriptor", null)), false);
+    }
+
+    // Install the shared mock instead of instantiating real schedulers.
+    @Override
+    protected void instantiateSchedulers(String host, int port, String trackingUrl,
+                                         AppContext appContext) {
+      taskSchedulers[0] = mockTaskScheduler;
+      taskSchedulerServiceWrappers[0] = new ServicePluginLifecycleAbstractService<>(taskSchedulers[0]);
+    }
+
+    @Override
+    protected void notifyForTest() {
+      synchronized (notify) {
+        notify.set(true);
+        notify.notifyAll();
+      }
+    }
+
+  }
+
+  AppContext mockAppContext;
+  DAGClientServer mockClientService;
+  TestEventHandler mockEventHandler;
+  ContainerSignatureMatcher mockSigMatcher;
+  MockTaskSchedulerManager schedulerHandler;
+  TaskScheduler mockTaskScheduler;
+  AMContainerMap mockAMContainerMap;
+  WebUIService mockWebUIService;
+
+  /** Creates fresh mocks and a new {@link MockTaskSchedulerManager} before every test. */
+  @Before
+  public void setup() {
+    mockAppContext = mock(AppContext.class, RETURNS_DEEP_STUBS);
+    doReturn(new Configuration(false)).when(mockAppContext).getAMConf();
+    mockClientService = mock(DAGClientServer.class);
+    mockEventHandler = new TestEventHandler();
+    mockSigMatcher = mock(ContainerSignatureMatcher.class);
+    mockTaskScheduler = mock(TaskScheduler.class);
+    mockAMContainerMap = mock(AMContainerMap.class);
+    mockWebUIService = mock(WebUIService.class);
+    when(mockAppContext.getAllContainers()).thenReturn(mockAMContainerMap);
+    when(mockClientService.getBindAddress()).thenReturn(new InetSocketAddress(10000));
+    schedulerHandler = new MockTaskSchedulerManager(
+        mockAppContext, mockClientService, mockEventHandler, mockSigMatcher, mockWebUIService);
+  }
+
+  /** Allocation emits an {@link AMContainerEventAssignTA} with the attempt id and priority. */
+  @Test(timeout = 5000)
+  public void testSimpleAllocate() throws Exception {
+    Configuration conf = new Configuration(false);
+    schedulerHandler.init(conf);
+    schedulerHandler.start();
+
+    TaskAttemptImpl mockTaskAttempt = mock(TaskAttemptImpl.class);
+    TezTaskAttemptID mockAttemptId = mock(TezTaskAttemptID.class);
+    when(mockAttemptId.getId()).thenReturn(0);
+    when(mockTaskAttempt.getID()).thenReturn(mockAttemptId);
+    Resource resource = Resource.newInstance(1024, 1);
+    ContainerContext containerContext =
+        new ContainerContext(new HashMap<String, LocalResource>(), new Credentials(),
+            new HashMap<String, String>(), "");
+    int priority = 10;
+    TaskLocationHint locHint = TaskLocationHint.createTaskLocationHint(new HashSet<String>(), null);
+
+    ContainerId mockCId = mock(ContainerId.class);
+    Container container = mock(Container.class);
+    when(container.getId()).thenReturn(mockCId);
+
+    AMContainer mockAMContainer = mock(AMContainer.class);
+    when(mockAMContainer.getContainerId()).thenReturn(mockCId);
+    when(mockAMContainer.getState()).thenReturn(AMContainerState.IDLE);
+
+    when(mockAMContainerMap.get(mockCId)).thenReturn(mockAMContainer);
+
+    AMSchedulerEventTALaunchRequest lr =
+        new AMSchedulerEventTALaunchRequest(mockAttemptId, resource, null, mockTaskAttempt, locHint,
+            priority, containerContext, 0, 0, 0);
+    schedulerHandler.taskAllocated(0, mockTaskAttempt, lr, container);
+    assertEquals(2, mockEventHandler.events.size());
+    assertTrue(mockEventHandler.events.get(1) instanceof AMContainerEventAssignTA);
+    AMContainerEventAssignTA assignEvent =
+        (AMContainerEventAssignTA) mockEventHandler.events.get(1);
+    assertEquals(priority, assignEvent.getPriority());
+    assertEquals(mockAttemptId, assignEvent.getTaskAttemptId());
+
+    schedulerHandler.stop();
+    schedulerHandler.close();
+  }
+
+  /** A location hint naming a source task pins allocation to that task's successful container. */
+  @Test (timeout = 5000)
+  public void testTaskBasedAffinity() throws Exception {
+    Configuration conf = new Configuration(false);
+    schedulerHandler.init(conf);
+    schedulerHandler.start();
+
+    TaskAttemptImpl mockTaskAttempt = mock(TaskAttemptImpl.class);
+    TezTaskAttemptID taId = mock(TezTaskAttemptID.class);
+    String affVertexName = "srcVertex";
+    int affTaskIndex = 1;
+    TaskLocationHint locHint = TaskLocationHint.createTaskLocationHint(affVertexName, affTaskIndex);
+    VertexImpl affVertex = mock(VertexImpl.class);
+    TaskImpl affTask = mock(TaskImpl.class);
+    TaskAttemptImpl affAttempt = mock(TaskAttemptImpl.class);
+    ContainerId affCId = mock(ContainerId.class);
+    when(affVertex.getTotalTasks()).thenReturn(2);
+    when(affVertex.getTask(affTaskIndex)).thenReturn(affTask);
+    when(affTask.getSuccessfulAttempt()).thenReturn(affAttempt);
+    when(affAttempt.getAssignedContainerID()).thenReturn(affCId);
+    when(mockAppContext.getCurrentDAG().getVertex(affVertexName)).thenReturn(affVertex);
+    Resource resource = Resource.newInstance(100, 1);
+    AMSchedulerEventTALaunchRequest event = new AMSchedulerEventTALaunchRequest
+        (taId, resource, null, mockTaskAttempt, locHint, 3, null, 0, 0, 0);
+    schedulerHandler.notify.set(false);
+    schedulerHandler.handle(event);
+    // Wait until the manager signals, via notifyForTest(), that it processed the event.
+    synchronized (schedulerHandler.notify) {
+      while (!schedulerHandler.notify.get()) {
+        schedulerHandler.notify.wait();
+      }
+    }
+
+    // verify mockTaskAttempt affinitized to expected affCId
+    verify(mockTaskScheduler, times(1)).allocateTask(mockTaskAttempt, resource, affCId,
+        Priority.newInstance(3), null, event);
+
+    schedulerHandler.stop();
+    schedulerHandler.close();
+  }
+
+  /** An RM-preempted container (exit status PREEMPTED) maps to EXTERNAL_PREEMPTION. */
+  @Test (timeout = 5000)
+  public void testContainerPreempted() throws IOException {
+    Configuration conf = new Configuration(false);
+    schedulerHandler.init(conf);
+    schedulerHandler.start();
+
+    String diagnostics = "Container preempted by RM.";
+    TaskAttemptImpl mockTask = mock(TaskAttemptImpl.class);
+    ContainerStatus mockStatus = mock(ContainerStatus.class);
+    ContainerId mockCId = mock(ContainerId.class);
+    AMContainer mockAMContainer = mock(AMContainer.class);
+    when(mockAMContainerMap.get(mockCId)).thenReturn(mockAMContainer);
+    when(mockAMContainer.getContainerId()).thenReturn(mockCId);
+    when(mockStatus.getContainerId()).thenReturn(mockCId);
+    when(mockStatus.getDiagnostics()).thenReturn(diagnostics);
+    when(mockStatus.getExitStatus()).thenReturn(ContainerExitStatus.PREEMPTED);
+    schedulerHandler.containerCompleted(0, mockTask, mockStatus);
+    assertEquals(1, mockEventHandler.events.size());
+    Event event = mockEventHandler.events.get(0);
+    assertEquals(AMContainerEventType.C_COMPLETED, event.getType());
+    AMContainerEventCompleted completedEvent = (AMContainerEventCompleted) event;
+    assertEquals(mockCId, completedEvent.getContainerId());
+    assertEquals("Container preempted externally. Container preempted by RM.",
+        completedEvent.getDiagnostics());
+    assertTrue(completedEvent.isPreempted());
+    assertEquals(TaskAttemptTerminationCause.EXTERNAL_PREEMPTION,
+        completedEvent.getTerminationCause());
+    Assert.assertFalse(completedEvent.isDiskFailed());
+
+    schedulerHandler.stop();
+    schedulerHandler.close();
+  }
+
+  /** Preemption via preemptContainer() deallocates the container and maps to INTERNAL_PREEMPTION. */
+  @Test (timeout = 5000)
+  public void testContainerInternalPreempted() throws IOException {
+    Configuration conf = new Configuration(false);
+    schedulerHandler.init(conf);
+    schedulerHandler.start();
+
+    AMContainer mockAmContainer = mock(AMContainer.class);
+    when(mockAmContainer.getTaskSchedulerIdentifier()).thenReturn(0);
+    when(mockAmContainer.getContainerLauncherIdentifier()).thenReturn(0);
+    when(mockAmContainer.getTaskCommunicatorIdentifier()).thenReturn(0);
+    ContainerId mockCId = mock(ContainerId.class);
+    verify(mockTaskScheduler, times(0)).deallocateContainer((ContainerId) any());
+    when(mockAMContainerMap.get(mockCId)).thenReturn(mockAmContainer);
+    schedulerHandler.preemptContainer(0, mockCId);
+    verify(mockTaskScheduler, times(1)).deallocateContainer(mockCId);
+    assertEquals(1, mockEventHandler.events.size());
+    Event event = mockEventHandler.events.get(0);
+    assertEquals(AMContainerEventType.C_COMPLETED, event.getType());
+    AMContainerEventCompleted completedEvent = (AMContainerEventCompleted) event;
+    assertEquals(mockCId, completedEvent.getContainerId());
+    assertEquals("Container preempted internally", completedEvent.getDiagnostics());
+    assertTrue(completedEvent.isPreempted());
+    Assert.assertFalse(completedEvent.isDiskFailed());
+    assertEquals(TaskAttemptTerminationCause.INTERNAL_PREEMPTION,
+        completedEvent.getTerminationCause());
+
+    schedulerHandler.stop();
+    schedulerHandler.close();
+  }
+
+  /** A container lost to an NM disk failure (exit status DISKS_FAILED) maps to NODE_DISK_ERROR. */
+  @Test (timeout = 5000)
+  public void testContainerDiskFailed() throws IOException {
+    Configuration conf = new Configuration(false);
+    schedulerHandler.init(conf);
+    schedulerHandler.start();
+
+    String diagnostics = "NM disk failed.";
+    TaskAttemptImpl mockTask = mock(TaskAttemptImpl.class);
+    ContainerStatus mockStatus = mock(ContainerStatus.class);
+    ContainerId mockCId = mock(ContainerId.class);
+    AMContainer mockAMContainer = mock(AMContainer.class);
+    when(mockAMContainerMap.get(mockCId)).thenReturn(mockAMContainer);
+    when(mockAMContainer.getContainerId()).thenReturn(mockCId);
+    when(mockStatus.getContainerId()).thenReturn(mockCId);
+    when(mockStatus.getDiagnostics()).thenReturn(diagnostics);
+    when(mockStatus.getExitStatus()).thenReturn(ContainerExitStatus.DISKS_FAILED);
+    schedulerHandler.containerCompleted(0, mockTask, mockStatus);
+    assertEquals(1, mockEventHandler.events.size());
+    Event event = mockEventHandler.events.get(0);
+    assertEquals(AMContainerEventType.C_COMPLETED, event.getType());
+    AMContainerEventCompleted completedEvent = (AMContainerEventCompleted) event;
+    assertEquals(mockCId, completedEvent.getContainerId());
+    assertEquals("Container disk failed. NM disk failed.",
+        completedEvent.getDiagnostics());
+    Assert.assertFalse(completedEvent.isPreempted());
+    assertTrue(completedEvent.isDiskFailed());
+    assertEquals(TaskAttemptTerminationCause.NODE_DISK_ERROR,
+        completedEvent.getTerminationCause());
+
+    schedulerHandler.stop();
+    schedulerHandler.close();
+  }
+
+  /** A container killed for exceeding physical memory (exit status -104) maps to CONTAINER_EXITED. */
+  @Test (timeout = 5000)
+  public void testContainerExceededPMem() throws IOException {
+    Configuration conf = new Configuration(false);
+    schedulerHandler.init(conf);
+    schedulerHandler.start();
+
+    String diagnostics = "Exceeded Physical Memory";
+    TaskAttemptImpl mockTask = mock(TaskAttemptImpl.class);
+    ContainerStatus mockStatus = mock(ContainerStatus.class);
+    ContainerId mockCId = mock(ContainerId.class);
+    AMContainer mockAMContainer = mock(AMContainer.class);
+    when(mockAMContainerMap.get(mockCId)).thenReturn(mockAMContainer);
+    when(mockAMContainer.getContainerId()).thenReturn(mockCId);
+    when(mockStatus.getContainerId()).thenReturn(mockCId);
+    when(mockStatus.getDiagnostics()).thenReturn(diagnostics);
+    // use -104 rather than ContainerExitStatus.KILLED_EXCEEDED_PMEM because
+    // ContainerExitStatus.KILLED_EXCEEDED_PMEM is only available after hadoop-2.5
+    when(mockStatus.getExitStatus()).thenReturn(-104);
+    schedulerHandler.containerCompleted(0, mockTask, mockStatus);
+    assertEquals(1, mockEventHandler.events.size());
+    Event event = mockEventHandler.events.get(0);
+    assertEquals(AMContainerEventType.C_COMPLETED, event.getType());
+    AMContainerEventCompleted completedEvent = (AMContainerEventCompleted) event;
+    assertEquals(mockCId, completedEvent.getContainerId());
+    assertEquals("Container failed, exitCode=-104. Exceeded Physical Memory",
+        completedEvent.getDiagnostics());
+    Assert.assertFalse(completedEvent.isPreempted());
+    Assert.assertFalse(completedEvent.isDiskFailed());
+    assertEquals(TaskAttemptTerminationCause.CONTAINER_EXITED,
+        completedEvent.getTerminationCause());
+
+    schedulerHandler.stop();
+    schedulerHandler.close();
+  }
+
+  /** History URL is empty unless ATS logging is configured, then is built from base and template. */
+  @Test (timeout = 5000)
+  public void testHistoryUrlConf() throws Exception {
+    Configuration conf = schedulerHandler.appContext.getAMConf();
+
+    // ensure history url is empty when timeline server is not the logging class
+    conf.set(TezConfiguration.TEZ_HISTORY_URL_BASE, "http://ui-host:9999");
+    assertEquals("", schedulerHandler.getHistoryUrl());
+
+    // ensure expansion of url happens
+    conf.set(TezConfiguration.TEZ_HISTORY_LOGGING_SERVICE_CLASS,
+        "org.apache.tez.dag.history.logging.ats.ATSHistoryLoggingService");
+    final ApplicationId mockApplicationId = mock(ApplicationId.class);
+    doReturn("TEST_APP_ID").when(mockApplicationId).toString();
+    doReturn(mockApplicationId).when(mockAppContext).getApplicationID();
+    assertEquals("http://ui-host:9999/#/tez-app/TEST_APP_ID",
+        schedulerHandler.getHistoryUrl());
+
+    // ensure the trailing / in history url is handled
+    conf.set(TezConfiguration.TEZ_HISTORY_URL_BASE, "http://ui-host:9998/");
+    assertEquals("http://ui-host:9998/#/tez-app/TEST_APP_ID",
+        schedulerHandler.getHistoryUrl());
+
+    // ensure missing scheme in history url is handled
+    conf.set(TezConfiguration.TEZ_HISTORY_URL_BASE, "ui-host:9998/");
+    assertEquals("http://ui-host:9998/#/tez-app/TEST_APP_ID",
+        schedulerHandler.getHistoryUrl());
+
+    // handle a bad template, e.g. one whose path lacks a beginning /
+    conf.set(TezConfiguration.TEZ_AM_TEZ_UI_HISTORY_URL_TEMPLATE,
+        "__HISTORY_URL_BASE__#/somepath");
+    assertEquals("http://ui-host:9998/#/somepath",
+        schedulerHandler.getHistoryUrl());
+
+    conf.set(TezConfiguration.TEZ_AM_TEZ_UI_HISTORY_URL_TEMPLATE,
+        "__HISTORY_URL_BASE__?viewPath=tez-app/__APPLICATION_ID__");
+    conf.set(TezConfiguration.TEZ_HISTORY_URL_BASE, "http://localhost/ui/tez");
+    assertEquals("http://localhost/ui/tez?viewPath=tez-app/TEST_APP_ID",
+        schedulerHandler.getHistoryUrl());
+
+  }
+
+  /** Constructing a manager without any scheduler descriptors must fail. */
+  @Test(timeout = 5000)
+  public void testNoSchedulerSpecified() throws IOException {
+    try {
+      new TSEHForMultipleSchedulersTest(mockAppContext, mockClientService, mockEventHandler,
+          mockSigMatcher, mockWebUIService, null, false);
+      fail("Expecting an IllegalArgumentException with no schedulers specified");
+    } catch (IllegalArgumentException expected) {
+      // Expected: at least one task scheduler descriptor is required.
+    }
+  }
+
+  /** Custom and YARN schedulers are created in descriptor order, each with its own payload. */
+  @Test(timeout = 5000)
+  public void testCustomTaskSchedulerSetup() throws IOException {
+    Configuration conf = new Configuration(false);
+    conf.set("testkey", "testval");
+    UserPayload defaultPayload = TezUtils.createUserPayloadFromConf(conf);
+
+    String customSchedulerName = "fakeScheduler";
+    List<NamedEntityDescriptor> taskSchedulers = new LinkedList<>();
+    ByteBuffer bb = ByteBuffer.allocate(4);
+    bb.putInt(0, 3);
+    UserPayload userPayload = UserPayload.create(bb);
+    taskSchedulers.add(
+        new NamedEntityDescriptor(customSchedulerName, FakeTaskScheduler.class.getName())
+            .setUserPayload(userPayload));
+    taskSchedulers.add(new NamedEntityDescriptor(TezConstants.getTezYarnServicePluginName(), null)
+        .setUserPayload(defaultPayload));
+
+    TSEHForMultipleSchedulersTest tseh =
+        new TSEHForMultipleSchedulersTest(mockAppContext, mockClientService, mockEventHandler,
+            mockSigMatcher, mockWebUIService, taskSchedulers, false);
+
+    tseh.init(conf);
+    tseh.start();
+
+    // Verify that the YARN task scheduler was created and the uber scheduler was not
+    assertTrue(tseh.getYarnSchedulerCreated());
+    assertFalse(tseh.getUberSchedulerCreated());
+    assertEquals(2, tseh.getNumCreateInvocations());
+
+    // Verify the order of the schedulers
+    assertEquals(customSchedulerName, tseh.getTaskSchedulerName(0));
+    assertEquals(TezConstants.getTezYarnServicePluginName(), tseh.getTaskSchedulerName(1));
+
+    // Verify the payload setup for the custom task scheduler
+    assertNotNull(tseh.getTaskSchedulerContext(0));
+    assertEquals(bb, tseh.getTaskSchedulerContext(0).getInitialUserPayload().getPayload());
+
+    // Verify the payload on the yarn scheduler
+    assertNotNull(tseh.getTaskSchedulerContext(1));
+    Configuration parsed = TezUtils.createConfFromUserPayload(tseh.getTaskSchedulerContext(1).getInitialUserPayload());
+    assertEquals("testval", parsed.get("testkey"));
+  }
+
+  /** Launch requests are dispatched to the task scheduler identified in the request. */
+  @Test(timeout = 5000)
+  public void testTaskSchedulerRouting() throws Exception {
+    Configuration conf = new Configuration(false);
+    UserPayload defaultPayload = TezUtils.createUserPayloadFromConf(conf);
+
+    String customSchedulerName = "fakeScheduler";
+    List<NamedEntityDescriptor> taskSchedulers = new LinkedList<>();
+    ByteBuffer bb = ByteBuffer.allocate(4);
+    bb.putInt(0, 3);
+    UserPayload userPayload = UserPayload.create(bb);
+    taskSchedulers.add(
+        new NamedEntityDescriptor(customSchedulerName, FakeTaskScheduler.class.getName())
+            .setUserPayload(userPayload));
+    taskSchedulers.add(new NamedEntityDescriptor(TezConstants.getTezYarnServicePluginName(), null)
+        .setUserPayload(defaultPayload));
+
+    TSEHForMultipleSchedulersTest tseh =
+        new TSEHForMultipleSchedulersTest(mockAppContext, mockClientService, mockEventHandler,
+            mockSigMatcher, mockWebUIService, taskSchedulers, false);
+
+    tseh.init(conf);
+    tseh.start();
+
+    // Verify that the YARN task scheduler was created and the uber scheduler was not
+    assertTrue(tseh.getYarnSchedulerCreated());
+    assertFalse(tseh.getUberSchedulerCreated());
+    assertEquals(2, tseh.getNumCreateInvocations());
+
+    // Verify the order of the schedulers
+    assertEquals(customSchedulerName, tseh.getTaskSchedulerName(0));
+    assertEquals(TezConstants.getTezYarnServicePluginName(), tseh.getTaskSchedulerName(1));
+
+    verify(tseh.getTestTaskScheduler(0)).initialize();
+    verify(tseh.getTestTaskScheduler(0)).start();
+
+    ApplicationId appId = ApplicationId.newInstance(1000, 1);
+    TezDAGID dagId = TezDAGID.getInstance(appId, 1);
+    TezVertexID vertexID = TezVertexID.getInstance(dagId, 1);
+    TezTaskID taskId1 = TezTaskID.getInstance(vertexID, 1);
+    TezTaskAttemptID attemptId11 = TezTaskAttemptID.getInstance(taskId1, 1);
+    TezTaskID taskId2 = TezTaskID.getInstance(vertexID, 2);
+    TezTaskAttemptID attemptId21 = TezTaskAttemptID.getInstance(taskId2, 1);
+
+    Resource resource = Resource.newInstance(1024, 1);
+
+    TaskAttempt mockTaskAttempt1 = mock(TaskAttempt.class);
+    TaskAttempt mockTaskAttempt2 = mock(TaskAttempt.class);
+
+    // Scheduler id 0 (first of the trailing ids) routes to the custom scheduler at index 0.
+    AMSchedulerEventTALaunchRequest launchRequest1 =
+        new AMSchedulerEventTALaunchRequest(attemptId11, resource, mock(TaskSpec.class),
+            mockTaskAttempt1, mock(TaskLocationHint.class), 1, mock(ContainerContext.class), 0, 0,
+            0);
+
+    tseh.handle(launchRequest1);
+
+    verify(tseh.getTestTaskScheduler(0)).allocateTask(eq(mockTaskAttempt1), eq(resource),
+        any(String[].class), any(String[].class), any(Priority.class), any(Object.class),
+        eq(launchRequest1));
+
+    // Scheduler id 1 routes to the YARN scheduler at index 1.
+    AMSchedulerEventTALaunchRequest launchRequest2 =
+        new AMSchedulerEventTALaunchRequest(attemptId21, resource, mock(TaskSpec.class),
+            mockTaskAttempt2, mock(TaskLocationHint.class), 1, mock(ContainerContext.class), 1, 0,
+            0);
+    tseh.handle(launchRequest2);
+    verify(tseh.getTestTaskScheduler(1)).allocateTask(eq(mockTaskAttempt2), eq(resource),
+        any(String[].class), any(String[].class), any(Priority.class), any(Object.class),
+        eq(launchRequest2));
+  }
+
+  /**
+   * {@link TaskSchedulerManager} that records scheduler creation, substitutes mocks for the YARN
+   * and uber schedulers, spies on custom schedulers and handles events inline on the caller thread.
+   */
+  private static class TSEHForMultipleSchedulersTest extends TaskSchedulerManager {
+
+    private final TaskScheduler yarnTaskScheduler;
+    private final TaskScheduler uberTaskScheduler;
+    private final AtomicBoolean uberSchedulerCreated = new AtomicBoolean(false);
+    private final AtomicBoolean yarnSchedulerCreated = new AtomicBoolean(false);
+    private final AtomicInteger numCreateInvocations = new AtomicInteger(0);
+    private final Set<Integer> seenSchedulers = new HashSet<>();
+    private final List<TaskSchedulerContext> taskSchedulerContexts = new LinkedList<>();
+    private final List<String> taskSchedulerNames = new LinkedList<>();
+    private final List<TaskScheduler> testTaskSchedulers = new LinkedList<>();
+
+    public TSEHForMultipleSchedulersTest(AppContext appContext,
+                                         DAGClientServer clientService,
+                                         EventHandler eventHandler,
+                                         ContainerSignatureMatcher containerSignatureMatcher,
+                                         WebUIService webUI,
+                                         List<NamedEntityDescriptor> schedulerDescriptors,
+                                         boolean isPureLocalMode) {
+      super(appContext, clientService, eventHandler, containerSignatureMatcher, webUI,
+          schedulerDescriptors, isPureLocalMode);
+      yarnTaskScheduler = mock(TaskScheduler.class);
+      uberTaskScheduler = mock(TaskScheduler.class);
+    }
+
+    @Override
+    TaskScheduler createTaskScheduler(String host, int port, String trackingUrl,
+                                      AppContext appContext,
+                                      NamedEntityDescriptor taskSchedulerDescriptor,
+                                      long customAppIdIdentifier,
+                                      int schedulerId) {
+
+      numCreateInvocations.incrementAndGet();
+      boolean added = seenSchedulers.add(schedulerId);
+      assertTrue("Cannot add multiple schedulers with the same schedulerId", added);
+      taskSchedulerNames.add(taskSchedulerDescriptor.getEntityName());
+      return super.createTaskScheduler(host, port, trackingUrl, appContext, taskSchedulerDescriptor,
+          customAppIdIdentifier, schedulerId);
+    }
+
+    @Override
+    TaskSchedulerContext wrapTaskSchedulerContext(TaskSchedulerContext rawContext) {
+      // Avoid wrapping in threads
+      return rawContext;
+    }
+
+    @Override
+    TaskScheduler createYarnTaskScheduler(TaskSchedulerContext taskSchedulerContext, int schedulerId) {
+      taskSchedulerContexts.add(taskSchedulerContext);
+      testTaskSchedulers.add(yarnTaskScheduler);
+      yarnSchedulerCreated.set(true);
+      return yarnTaskScheduler;
+    }
+
+    @Override
+    TaskScheduler createUberTaskScheduler(TaskSchedulerContext taskSchedulerContext, int schedulerId) {
+      taskSchedulerContexts.add(taskSchedulerContext);
+      testTaskSchedulers.add(uberTaskScheduler);
+      uberSchedulerCreated.set(true);
+      return uberTaskScheduler;
+    }
+
+    @Override
+    TaskScheduler createCustomTaskScheduler(TaskSchedulerContext taskSchedulerContext,
+                                            NamedEntityDescriptor taskSchedulerDescriptor, int schedulerId) {
+      taskSchedulerContexts.add(taskSchedulerContext);
+      TaskScheduler taskScheduler = spy(super.createCustomTaskScheduler(taskSchedulerContext, taskSchedulerDescriptor, schedulerId));
+      testTaskSchedulers.add(taskScheduler);
+      return taskScheduler;
+    }
+
+    // Process events inline on the calling thread instead of the base class's dispatch path.
+    @Override
+    public void handle(AMSchedulerEvent event) {
+      handleEvent(event);
+    }
+
+    public boolean getUberSchedulerCreated() {
+      return uberSchedulerCreated.get();
+    }
+
+    public boolean getYarnSchedulerCreated() {
+      return yarnSchedulerCreated.get();
+    }
+
+    public int getNumCreateInvocations() {
+      return numCreateInvocations.get();
+    }
+
+    public TaskSchedulerContext getTaskSchedulerContext(int schedulerId) {
+      return taskSchedulerContexts.get(schedulerId);
+    }
+
+    public String getTaskSchedulerName(int schedulerId) {
+      return taskSchedulerNames.get(schedulerId);
+    }
+
+    public TaskScheduler getTestTaskScheduler(int schedulerId) {
+      return testTaskSchedulers.get(schedulerId);
+    }
+  }
+
+  /** No-op {@link TaskScheduler} instantiated by class name as the custom scheduler in tests. */
+  public static class FakeTaskScheduler extends TaskScheduler {
+
+    public FakeTaskScheduler(
+        TaskSchedulerContext taskSchedulerContext) {
+      super(taskSchedulerContext);
+    }
+
+    @Override
+    public Resource getAvailableResources() {
+      return null;
+    }
+
+    @Override
+    public int getClusterNodeCount() {
+      return 0;
+    }
+
+    @Override
+    public void dagComplete() {
+
+    }
+
+    @Override
+    public Resource getTotalResources() {
+      return null;
+    }
+
+    @Override
+    public void blacklistNode(NodeId nodeId) {
+
+    }
+
+    @Override
+    public void unblacklistNode(NodeId nodeId) {
+
+    }
+
+    @Override
+    public void allocateTask(Object task, Resource capability, String[] hosts, String[] racks,
+                             Priority priority, Object containerSignature, Object clientCookie) {
+
+    }
+
+    @Override
+    public void allocateTask(Object task, Resource capability, ContainerId containerId,
+                             Priority priority, Object containerSignature, Object clientCookie) {
+
+    }
+
+    @Override
+    public boolean deallocateTask(Object task, boolean taskSucceeded,
+                                  TaskAttemptEndReason endReason,
+                                  String diagnostics) {
+      return false;
+    }
+
+    @Override
+    public Object deallocateContainer(ContainerId containerId) {
+      return null;
+    }
+
+    @Override
+    public void setShouldUnregister() {
+
+    }
+
+    @Override
+    public boolean hasUnregistered() {
+      return false;
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/tez/blob/8b278ea8/tez-dag/src/test/java/org/apache/tez/dag/app/rm/container/TestAMContainer.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/rm/container/TestAMContainer.java b/tez-dag/src/test/java/org/apache/tez/dag/app/rm/container/TestAMContainer.java
index 13fa4c5..cc88f0d 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/app/rm/container/TestAMContainer.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/rm/container/TestAMContainer.java
@@ -69,14 +69,14 @@ import org.apache.tez.dag.api.TaskCommunicator;
 import org.apache.tez.dag.app.AppContext;
 import org.apache.tez.dag.app.ContainerHeartbeatHandler;
 import org.apache.tez.dag.app.ContainerContext;
-import org.apache.tez.dag.app.TaskAttemptListener;
+import org.apache.tez.dag.app.TaskCommunicatorManagerInterface;
 import org.apache.tez.dag.app.dag.event.TaskAttemptEventContainerTerminated;
 import org.apache.tez.dag.app.dag.event.TaskAttemptEventContainerTerminatedBySystem;
 import org.apache.tez.dag.app.dag.event.TaskAttemptEventContainerTerminating;
 import org.apache.tez.dag.app.dag.event.TaskAttemptEventNodeFailed;
 import org.apache.tez.dag.app.dag.event.TaskAttemptEventType;
 import org.apache.tez.dag.app.rm.AMSchedulerEventType;
-import org.apache.tez.dag.app.rm.NMCommunicatorEventType;
+import org.apache.tez.dag.app.rm.ContainerLauncherEventType;
 import org.apache.tez.dag.history.DAGHistoryEvent;
 import org.apache.tez.dag.history.HistoryEventHandler;
 import org.apache.tez.dag.records.TaskAttemptTerminationCause;
@@ -288,7 +288,7 @@ public class TestAMContainer {
     // Event to NM to stop the container.
     wc.verifyCountAndGetOutgoingEvents(1);
     assertTrue(wc.verifyCountAndGetOutgoingEvents(1).get(0).getType() ==
-        NMCommunicatorEventType.CONTAINER_STOP_REQUEST);
+        ContainerLauncherEventType.CONTAINER_STOP_REQUEST);
 
     wc.nmStopSent();
     wc.verifyState(AMContainerState.STOPPING);
@@ -323,7 +323,7 @@ public class TestAMContainer {
     // Event to NM to stop the container.
     wc.verifyCountAndGetOutgoingEvents(1);
     assertTrue(wc.verifyCountAndGetOutgoingEvents(1).get(0).getType() ==
-        NMCommunicatorEventType.CONTAINER_STOP_REQUEST);
+        ContainerLauncherEventType.CONTAINER_STOP_REQUEST);
 
     wc.nmStopFailed();
     wc.verifyState(AMContainerState.STOPPING);
@@ -366,7 +366,7 @@ public class TestAMContainer {
     // 1 for NM stop request. 2 TERMINATING to TaskAttempt.
     outgoingEvents = wc.verifyCountAndGetOutgoingEvents(3);
     verifyUnOrderedOutgoingEventTypes(outgoingEvents,
-        NMCommunicatorEventType.CONTAINER_STOP_REQUEST,
+        ContainerLauncherEventType.CONTAINER_STOP_REQUEST,
         TaskAttemptEventType.TA_CONTAINER_TERMINATING,
         TaskAttemptEventType.TA_CONTAINER_TERMINATING);
     assertTrue(wc.amContainer.isInErrorState());
@@ -405,7 +405,7 @@ public class TestAMContainer {
     // 1 for NM stop request. 2 TERMINATING to TaskAttempt.
     outgoingEvents = wc.verifyCountAndGetOutgoingEvents(3);
     verifyUnOrderedOutgoingEventTypes(outgoingEvents,
-        NMCommunicatorEventType.CONTAINER_STOP_REQUEST,
+        ContainerLauncherEventType.CONTAINER_STOP_REQUEST,
         TaskAttemptEventType.TA_CONTAINER_TERMINATING,
         TaskAttemptEventType.TA_CONTAINER_TERMINATING);
     assertTrue(wc.amContainer.isInErrorState());
@@ -443,7 +443,7 @@ public class TestAMContainer {
     outgoingEvents = wc.verifyCountAndGetOutgoingEvents(2);
     verifyUnOrderedOutgoingEventTypes(outgoingEvents,
         TaskAttemptEventType.TA_CONTAINER_TERMINATING,
-        NMCommunicatorEventType.CONTAINER_STOP_REQUEST);
+        ContainerLauncherEventType.CONTAINER_STOP_REQUEST);
     // TODO Should this be an RM DE-ALLOCATE instead ?
 
     wc.containerCompleted();
@@ -478,7 +478,7 @@ public class TestAMContainer {
     outgoingEvents = wc.verifyCountAndGetOutgoingEvents(2);
     verifyUnOrderedOutgoingEventTypes(outgoingEvents,
         TaskAttemptEventType.TA_CONTAINER_TERMINATING,
-        NMCommunicatorEventType.CONTAINER_STOP_REQUEST);
+        ContainerLauncherEventType.CONTAINER_STOP_REQUEST);
     // TODO Should this be an RM DE-ALLOCATE instead ?
 
     wc.containerCompleted();
@@ -1194,7 +1194,7 @@ public class TestAMContainer {
     Priority priority;
     Container container;
     ContainerHeartbeatHandler chh;
-    TaskAttemptListener tal;
+    TaskCommunicatorManagerInterface tal;
 
     @SuppressWarnings("rawtypes")
     EventHandler eventHandler;
@@ -1226,7 +1226,7 @@ public class TestAMContainer {
 
       chh = mock(ContainerHeartbeatHandler.class);
 
-      tal = mock(TaskAttemptListener.class);
+      tal = mock(TaskCommunicatorManagerInterface.class);
       TaskCommunicator taskComm = mock(TaskCommunicator.class);
       doReturn(new InetSocketAddress("localhost", 0)).when(taskComm).getAddress();
       doReturn(taskComm).when(tal).getTaskCommunicator(0);
@@ -1440,7 +1440,7 @@ public class TestAMContainer {
     return lr;
   }
 
-  private void verifyUnregisterRunningContainer(TaskAttemptListener tal, ContainerId containerId,
+  private void verifyUnregisterRunningContainer(TaskCommunicatorManagerInterface tal, ContainerId containerId,
                                                 int taskCommId,
                                                 ContainerEndReason containerEndReason,
                                                 String diagContains) {
@@ -1455,7 +1455,7 @@ public class TestAMContainer {
     }
   }
 
-  private void verifyUnregisterTaskAttempt(TaskAttemptListener tal, TezTaskAttemptID taId,
+  private void verifyUnregisterTaskAttempt(TaskCommunicatorManagerInterface tal, TezTaskAttemptID taId,
                                            int taskCommId, TaskAttemptEndReason endReason,
                                            String diagContains) {
     ArgumentCaptor<String> argumentCaptor = ArgumentCaptor.forClass(String.class);

http://git-wip-us.apache.org/repos/asf/tez/blob/8b278ea8/tez-dag/src/test/java/org/apache/tez/dag/app/rm/container/TestAMContainerMap.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/rm/container/TestAMContainerMap.java b/tez-dag/src/test/java/org/apache/tez/dag/app/rm/container/TestAMContainerMap.java
index dee4541..0230bb5 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/app/rm/container/TestAMContainerMap.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/rm/container/TestAMContainerMap.java
@@ -34,7 +34,7 @@ import org.apache.hadoop.yarn.api.records.Token;
 import org.apache.tez.dag.api.TaskCommunicator;
 import org.apache.tez.dag.app.AppContext;
 import org.apache.tez.dag.app.ContainerHeartbeatHandler;
-import org.apache.tez.dag.app.TaskAttemptListener;
+import org.apache.tez.dag.app.TaskCommunicatorManagerInterface;
 
 public class TestAMContainerMap {
 
@@ -42,8 +42,8 @@ public class TestAMContainerMap {
     return mock(ContainerHeartbeatHandler.class);
   }
 
-  private TaskAttemptListener mockTaskAttemptListener() {
-    TaskAttemptListener tal = mock(TaskAttemptListener.class);
+  private TaskCommunicatorManagerInterface mockTaskAttemptListener() {
+    TaskCommunicatorManagerInterface tal = mock(TaskCommunicatorManagerInterface.class);
     TaskCommunicator taskComm = mock(TaskCommunicator.class);
     doReturn(new InetSocketAddress("localhost", 21000)).when(taskComm).getAddress();
     doReturn(taskComm).when(tal).getTaskCommunicator(0);

http://git-wip-us.apache.org/repos/asf/tez/blob/8b278ea8/tez-dag/src/test/java/org/apache/tez/dag/app/rm/node/TestAMNodeTracker.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/rm/node/TestAMNodeTracker.java b/tez-dag/src/test/java/org/apache/tez/dag/app/rm/node/TestAMNodeTracker.java
index def80da..143fcbf 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/app/rm/node/TestAMNodeTracker.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/rm/node/TestAMNodeTracker.java
@@ -44,7 +44,7 @@ import org.apache.tez.dag.api.TezConfiguration;
 import org.apache.tez.dag.app.AppContext;
 import org.apache.tez.dag.app.rm.AMSchedulerEventNodeBlacklistUpdate;
 import org.apache.tez.dag.app.rm.AMSchedulerEventType;
-import org.apache.tez.dag.app.rm.TaskSchedulerEventHandler;
+import org.apache.tez.dag.app.rm.TaskSchedulerManager;
 import org.apache.tez.dag.app.rm.container.AMContainerEventNodeFailed;
 import org.apache.tez.dag.app.rm.container.AMContainerEventType;
 import org.apache.tez.dag.app.rm.container.AMContainerMap;
@@ -186,11 +186,11 @@ public class TestAMNodeTracker {
     AMNodeTracker amNodeTracker = new AMNodeTracker(handler, appContext);
     doReturn(amNodeTracker).when(appContext).getNodeTracker();
     AMContainerMap amContainerMap = mock(AMContainerMap.class);
-    TaskSchedulerEventHandler taskSchedulerEventHandler =
-        mock(TaskSchedulerEventHandler.class);
+    TaskSchedulerManager taskSchedulerManager =
+        mock(TaskSchedulerManager.class);
     dispatcher.register(AMNodeEventType.class, amNodeTracker);
     dispatcher.register(AMContainerEventType.class, amContainerMap);
-    dispatcher.register(AMSchedulerEventType.class, taskSchedulerEventHandler);
+    dispatcher.register(AMSchedulerEventType.class, taskSchedulerManager);
     amNodeTracker.init(conf);
     amNodeTracker.start();
 
@@ -209,11 +209,11 @@ public class TestAMNodeTracker {
     AMNodeTracker amNodeTracker = new AMNodeTracker(handler, appContext);
     doReturn(amNodeTracker).when(appContext).getNodeTracker();
     AMContainerMap amContainerMap = mock(AMContainerMap.class);
-    TaskSchedulerEventHandler taskSchedulerEventHandler =
-        mock(TaskSchedulerEventHandler.class);
+    TaskSchedulerManager taskSchedulerManager =
+        mock(TaskSchedulerManager.class);
     dispatcher.register(AMNodeEventType.class, amNodeTracker);
     dispatcher.register(AMContainerEventType.class, amContainerMap);
-    dispatcher.register(AMSchedulerEventType.class, taskSchedulerEventHandler);
+    dispatcher.register(AMSchedulerEventType.class, taskSchedulerManager);
     amNodeTracker.init(conf);
     amNodeTracker.start();
 
@@ -232,11 +232,11 @@ public class TestAMNodeTracker {
     AMNodeTracker amNodeTracker = new AMNodeTracker(handler, appContext);
     doReturn(amNodeTracker).when(appContext).getNodeTracker();
     AMContainerMap amContainerMap = mock(AMContainerMap.class);
-    TaskSchedulerEventHandler taskSchedulerEventHandler =
-        mock(TaskSchedulerEventHandler.class);
+    TaskSchedulerManager taskSchedulerManager =
+        mock(TaskSchedulerManager.class);
     dispatcher.register(AMNodeEventType.class, amNodeTracker);
     dispatcher.register(AMContainerEventType.class, amContainerMap);
-    dispatcher.register(AMSchedulerEventType.class, taskSchedulerEventHandler);
+    dispatcher.register(AMSchedulerEventType.class, taskSchedulerManager);
     amNodeTracker.init(conf);
     amNodeTracker.start();
 
@@ -260,11 +260,11 @@ public class TestAMNodeTracker {
     AMNodeTracker amNodeTracker = new AMNodeTracker(handler, appContext);
     doReturn(amNodeTracker).when(appContext).getNodeTracker();
     AMContainerMap amContainerMap = mock(AMContainerMap.class);
-    TaskSchedulerEventHandler taskSchedulerEventHandler =
-        mock(TaskSchedulerEventHandler.class);
+    TaskSchedulerManager taskSchedulerManager =
+        mock(TaskSchedulerManager.class);
     dispatcher.register(AMNodeEventType.class, amNodeTracker);
     dispatcher.register(AMContainerEventType.class, amContainerMap);
-    dispatcher.register(AMSchedulerEventType.class, taskSchedulerEventHandler);
+    dispatcher.register(AMSchedulerEventType.class, taskSchedulerManager);
     amNodeTracker.init(conf);
     amNodeTracker.start();
     try {
@@ -283,11 +283,11 @@ public class TestAMNodeTracker {
     AMNodeTracker amNodeTracker = new AMNodeTracker(handler, appContext);
     doReturn(amNodeTracker).when(appContext).getNodeTracker();
     AMContainerMap amContainerMap = mock(AMContainerMap.class);
-    TaskSchedulerEventHandler taskSchedulerEventHandler =
-        mock(TaskSchedulerEventHandler.class);
+    TaskSchedulerManager taskSchedulerManager =
+        mock(TaskSchedulerManager.class);
     dispatcher.register(AMNodeEventType.class, amNodeTracker);
     dispatcher.register(AMContainerEventType.class, amContainerMap);
-    dispatcher.register(AMSchedulerEventType.class, taskSchedulerEventHandler);
+    dispatcher.register(AMSchedulerEventType.class, taskSchedulerManager);
     amNodeTracker.init(conf);
     amNodeTracker.start();
     try {
@@ -306,11 +306,11 @@ public class TestAMNodeTracker {
     AMNodeTracker amNodeTracker = new AMNodeTracker(handler, appContext);
     doReturn(amNodeTracker).when(appContext).getNodeTracker();
     AMContainerMap amContainerMap = mock(AMContainerMap.class);
-    TaskSchedulerEventHandler taskSchedulerEventHandler =
-        mock(TaskSchedulerEventHandler.class);
+    TaskSchedulerManager taskSchedulerManager =
+        mock(TaskSchedulerManager.class);
     dispatcher.register(AMNodeEventType.class, amNodeTracker);
     dispatcher.register(AMContainerEventType.class, amContainerMap);
-    dispatcher.register(AMSchedulerEventType.class, taskSchedulerEventHandler);
+    dispatcher.register(AMSchedulerEventType.class, taskSchedulerManager);
     amNodeTracker.init(conf);
     amNodeTracker.start();
     try {

http://git-wip-us.apache.org/repos/asf/tez/blob/8b278ea8/tez-ext-service-tests/src/test/java/org/apache/tez/service/impl/ContainerRunnerImpl.java
----------------------------------------------------------------------
diff --git a/tez-ext-service-tests/src/test/java/org/apache/tez/service/impl/ContainerRunnerImpl.java b/tez-ext-service-tests/src/test/java/org/apache/tez/service/impl/ContainerRunnerImpl.java
index 3b4c768..fb4c08f 100644
--- a/tez-ext-service-tests/src/test/java/org/apache/tez/service/impl/ContainerRunnerImpl.java
+++ b/tez-ext-service-tests/src/test/java/org/apache/tez/service/impl/ContainerRunnerImpl.java
@@ -60,7 +60,8 @@ import org.apache.tez.dag.api.TezConfiguration;
 import org.apache.tez.dag.api.TezException;
 import org.apache.tez.runtime.api.impl.TaskSpec;
 import org.apache.tez.runtime.task.TaskReporter;
-import org.apache.tez.runtime.task.TezTaskRunner;
+import org.apache.tez.runtime.task.TaskRunner2Result;
+import org.apache.tez.runtime.task.TezTaskRunner2;
 import org.apache.tez.service.ContainerRunner;
 import org.apache.tez.dag.api.TezConstants;
 import org.apache.tez.runtime.api.ExecutionContext;
@@ -378,7 +379,7 @@ public class ContainerRunnerImpl extends AbstractService implements ContainerRun
     private final Credentials credentials;
     private final long memoryAvailable;
     private final ListeningExecutorService executor;
-    private volatile TezTaskRunner taskRunner;
+    private volatile TezTaskRunner2 taskRunner;
     private volatile TaskReporter taskReporter;
     private TezTaskUmbilicalProtocol umbilical;
 
@@ -443,7 +444,7 @@ public class ContainerRunnerImpl extends AbstractService implements ContainerRun
           new AtomicLong(0),
           request.getContainerIdString());
 
-      taskRunner = new TezTaskRunner(conf, taskUgi, localDirs,
+      taskRunner = new TezTaskRunner2(conf, taskUgi, localDirs,
           ProtoConverters.getTaskSpecfromProto(request.getTaskSpec()),
           request.getAppAttemptNumber(),
           serviceConsumerMetadata, envMap, startedInputsMap, taskReporter, executor, objectRegistry,
@@ -452,18 +453,20 @@ public class ContainerRunnerImpl extends AbstractService implements ContainerRun
 
       boolean shouldDie;
       try {
-        shouldDie = !taskRunner.run();
+        TaskRunner2Result result = taskRunner.run();
+        LOG.info("TaskRunner2Result: {}", result);
+        shouldDie = result.isContainerShutdownRequested();
         if (shouldDie) {
           LOG.info("Got a shouldDie notification via heartbeats. Shutting down");
           return new ContainerExecutionResult(ContainerExecutionResult.ExitStatus.SUCCESS, null,
               "Asked to die by the AM");
         }
-      } catch (IOException e) {
-        return new ContainerExecutionResult(ContainerExecutionResult.ExitStatus.EXECUTION_FAILURE,
-            e, "TaskExecutionFailure: " + e.getMessage());
-      } catch (TezException e) {
-        return new ContainerExecutionResult(ContainerExecutionResult.ExitStatus.EXECUTION_FAILURE,
-            e, "TaskExecutionFailure: " + e.getMessage());
+        if (result.getError() != null) {
+          Throwable e = result.getError();
+          return new ContainerExecutionResult(
+              ContainerExecutionResult.ExitStatus.EXECUTION_FAILURE,
+              e, "TaskExecutionFailure: " + e.getMessage());
+        }
       } finally {
         FileSystem.closeAllForUGI(taskUgi);
       }

http://git-wip-us.apache.org/repos/asf/tez/blob/8b278ea8/tez-runtime-internals/src/main/java/org/apache/tez/runtime/task/TezTaskRunner.java
----------------------------------------------------------------------
diff --git a/tez-runtime-internals/src/main/java/org/apache/tez/runtime/task/TezTaskRunner.java b/tez-runtime-internals/src/main/java/org/apache/tez/runtime/task/TezTaskRunner.java
deleted file mode 100644
index aebf6a9..0000000
--- a/tez-runtime-internals/src/main/java/org/apache/tez/runtime/task/TezTaskRunner.java
+++ /dev/null
@@ -1,451 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.tez.runtime.task;
-
-import java.io.IOException;
-import java.lang.reflect.UndeclaredThrowableException;
-import java.nio.ByteBuffer;
-import java.security.PrivilegedExceptionAction;
-import java.util.Collection;
-import java.util.Map;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.concurrent.locks.Condition;
-import java.util.concurrent.locks.Lock;
-import java.util.concurrent.locks.ReentrantLock;
-
-import com.google.common.base.Throwables;
-import org.apache.commons.lang.exception.ExceptionUtils;
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.fs.FSError;
-import org.apache.hadoop.security.UserGroupInformation;
-import org.apache.tez.common.CallableWithNdc;
-import org.apache.tez.dag.api.TezException;
-import org.apache.tez.dag.records.TezTaskAttemptID;
-import org.apache.tez.runtime.LogicalIOProcessorRuntimeTask;
-import org.apache.tez.runtime.RuntimeTask;
-import org.apache.tez.runtime.api.ExecutionContext;
-import org.apache.tez.runtime.api.ObjectRegistry;
-import org.apache.tez.runtime.api.impl.EventMetaData;
-import org.apache.tez.runtime.api.impl.TaskSpec;
-import org.apache.tez.runtime.api.impl.TezEvent;
-import org.apache.tez.runtime.api.impl.TezUmbilical;
-import org.apache.tez.runtime.internals.api.TaskReporterInterface;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import com.google.common.collect.Multimap;
-import com.google.common.util.concurrent.ListenableFuture;
-import com.google.common.util.concurrent.ListeningExecutorService;
-
-public class TezTaskRunner implements TezUmbilical, ErrorReporter {
-
-  private static final Logger LOG = LoggerFactory.getLogger(TezTaskRunner.class);
-
-  private final Configuration tezConf;
-  private final LogicalIOProcessorRuntimeTask task;
-  private final UserGroupInformation ugi;
-
-  private final TaskReporterInterface taskReporter;
-  private final ListeningExecutorService executor;
-  private volatile ListenableFuture<Void> taskFuture;
-  private volatile Thread waitingThread;
-  private volatile Thread taskRunner;
-  private volatile Throwable firstException;
-
-  // Effectively a duplicate check, since hadFatalError does the same thing.
-  private final AtomicBoolean fatalErrorSent = new AtomicBoolean(false);
-  private final AtomicBoolean taskRunning;
-  private final AtomicBoolean shutdownRequested = new AtomicBoolean(false);
-
-  public TezTaskRunner(Configuration tezConf, UserGroupInformation ugi, String[] localDirs,
-      TaskSpec taskSpec, int appAttemptNumber,
-      Map<String, ByteBuffer> serviceConsumerMetadata, Map<String, String> serviceProviderEnvMap,
-      Multimap<String, String> startedInputsMap, TaskReporterInterface taskReporter,
-      ListeningExecutorService executor, ObjectRegistry objectRegistry, String pid,
-      ExecutionContext executionContext, long memAvailable)
-          throws IOException {
-    this.tezConf = tezConf;
-    this.ugi = ugi;
-    this.taskReporter = taskReporter;
-    this.executor = executor;
-    task = new LogicalIOProcessorRuntimeTask(taskSpec, appAttemptNumber, tezConf, localDirs, this,
-        serviceConsumerMetadata, serviceProviderEnvMap, startedInputsMap, objectRegistry, pid,
-        executionContext, memAvailable);
-    taskRunning = new AtomicBoolean(false);
-  }
-
-  /**
-   * @return false if a shutdown message was received during task execution
-   * @throws TezException
-   * @throws IOException
-   */
-  public boolean run() throws InterruptedException, IOException, TezException {
-    waitingThread = Thread.currentThread();
-    taskRunning.set(true);
-    taskReporter.registerTask(task, this);
-    TaskRunnerCallable callable = new TaskRunnerCallable();
-    Throwable failureCause = null;
-    if (!Thread.currentThread().isInterrupted()) {
-      taskFuture = executor.submit(callable);
-    } else {
-      taskReporter.unregisterTask(task.getTaskAttemptID());
-      return isShutdownRequested();
-    }
-    try {
-      taskFuture.get();
-
-      // Task could signal a fatal error and return control, or a failure while registering success.
-      failureCause = firstException;
-
-    } catch (InterruptedException e) {
-      LOG.info("Interrupted while waiting for task to complete. Interrupting task");
-      taskFuture.cancel(true);
-      if (shutdownRequested.get()) {
-        LOG.info("Shutdown requested... returning");
-        return false;
-      }
-      if (firstException != null) {
-        failureCause = firstException;
-      } else {
-        // Interrupted for some other reason.
-        failureCause = e;
-      }
-    } catch (ExecutionException e) {
-      // Exception thrown by the run() method itself.
-      Throwable cause = e.getCause();
-      if (cause instanceof FSError) {
-        // Not immediately fatal, this is an error reported by Hadoop FileSystem
-        failureCause = cause;
-      } else if (cause instanceof Error) {
-        LOG.error("Exception of type Error.", cause);
-        sendFailure(cause, "Fatal Error cause TezChild exit.");
-        throw new TezException("Fatal Error cause TezChild exit.", cause);
-      } else {
-        failureCause = cause;
-      }
-    } finally {
-      // Clear the interrupted status of the blocking thread, in case it is set after the
-      // InterruptedException was invoked.
-      taskReporter.unregisterTask(task.getTaskAttemptID());
-      Thread.interrupted();
-    }
-
-    if (failureCause != null) {
-      if (failureCause instanceof FSError) {
-        // Not immediately fatal, this is an error reported by Hadoop FileSystem
-        LOG.info("Encountered an FSError while executing task: " + task.getTaskAttemptID(),
-            failureCause);
-        throw (FSError) failureCause;
-      } else if (failureCause instanceof Error) {
-        LOG.error("Exception of type Error.", failureCause);
-        sendFailure(failureCause, "Fatal error cause TezChild exit.");
-        throw new TezException("Fatal error cause TezChild exit.", failureCause);
-      } else {
-        if (failureCause instanceof IOException) {
-          throw (IOException) failureCause;
-        } else if (failureCause instanceof TezException) {
-          throw (TezException) failureCause;
-        } else if (failureCause instanceof InterruptedException) {
-          throw (InterruptedException) failureCause;
-        } else {
-          throw new TezException(failureCause);
-        }
-      }
-    }
-    return isShutdownRequested();
-  }
-
-  private boolean isShutdownRequested() {
-    if (shutdownRequested.get()) {
-      LOG.info("Shutdown requested... returning");
-      return false;
-    }
-    return true;
-  }
-
-  private class TaskRunnerCallable extends CallableWithNdc<Void> {
-    @Override
-    protected Void callInternal() throws Exception {
-      try {
-        return ugi.doAs(new PrivilegedExceptionAction<Void>() {
-          @Override
-          public Void run() throws Exception {
-            try {
-              taskRunner = Thread.currentThread();
-              LOG.info("Initializing task" + ", taskAttemptId=" + task.getTaskAttemptID());
-              task.initialize();
-              if (!Thread.currentThread().isInterrupted() && firstException == null) {
-                LOG.info("Running task, taskAttemptId=" + task.getTaskAttemptID());
-                task.run();
-                maybeInterruptWaitingThread();
-
-                LOG.info("Closing task, taskAttemptId=" + task.getTaskAttemptID());
-                task.close();
-                task.setFrameworkCounters();
-              }
-              LOG.info("Task completed, taskAttemptId=" + task.getTaskAttemptID()
-                  + ", fatalErrorOccurred=" + (firstException != null));
-              if (firstException == null) {
-                try {
-                  taskReporter.taskSucceeded(task.getTaskAttemptID());
-                } catch (IOException e) {
-                  LOG.warn("Heartbeat failure caused by communication failure", e);
-                  maybeRegisterFirstException(e);
-                  // Falling off, since the runner thread checks for the registered exception.
-                } catch (TezException e) {
-                  LOG.warn("Heartbeat failure reported by AM", e);
-                  maybeRegisterFirstException(e);
-                  // Falling off, since the runner thread checks for the registered exception.
-                }
-              }
-              return null;
-            } catch (Throwable cause) {
-              if (Thread.currentThread().isInterrupted()) {
-                LOG.info("TaskRunnerCallable interrupted=" + Thread.currentThread().isInterrupted()
-                    + ", shutdownRequest=" + shutdownRequested.get());
-                Thread.currentThread().interrupt();
-                return null;
-              }
-              if (cause instanceof FSError) {
-                // Not immediately fatal, this is an error reported by Hadoop FileSystem
-                maybeRegisterFirstException(cause);
-                LOG.info("Encountered an FSError while executing task: " + task.getTaskAttemptID(),
-                    cause);
-                try {
-                  sendFailure(cause, "FS Error in Child JVM");
-                } catch (Exception ignored) {
-                  // Ignored since another cause is already known
-                  LOG.info(
-                      "Ignoring the following exception since a previous exception is already registered",
-                      ignored.getClass().getName());
-                  if (LOG.isTraceEnabled()) {
-                    LOG.trace("Ignored exception is", ignored);
-                  }
-                }
-                throw (FSError) cause;
-              } else if (cause instanceof Error) {
-                LOG.error("Exception of type Error.", cause);
-                sendFailure(cause, "Fatal Error cause TezChild exit.");
-                throw new TezException("Fatal Error cause TezChild exit.", cause);
-              } else {
-                if (cause instanceof UndeclaredThrowableException) {
-                  cause = ((UndeclaredThrowableException) cause).getCause();
-                }
-                maybeRegisterFirstException(cause);
-                LOG.info("Encountered an error while executing task: " + task.getTaskAttemptID(),
-                    cause);
-                try {
-                  sendFailure(cause, "Failure while running task");
-                } catch (Exception ignored) {
-                  // Ignored since another cause is already known
-                  LOG.info(
-                      "Ignoring the following exception since a previous exception is already registered",
-                      ignored.getClass().getName());
-                  if (LOG.isTraceEnabled()) {
-                    LOG.trace("Ignored exception is", ignored);
-                  }
-                }
-                if (cause instanceof IOException) {
-                  throw (IOException) cause;
-                } else if (cause instanceof TezException) {
-                  throw (TezException) cause;
-                } else {
-                  throw new TezException(cause);
-                }
-              }
-            } finally {
-              task.cleanup();
-            }
-          }
-        });
-      } finally {
-        taskRunning.set(false);
-      }
-    }
-
-    private void maybeInterruptWaitingThread() {
-      /**
-       * Possible that the processor is swallowing InterruptException of taskRunner.interrupt().
-       * In such case, interrupt the waitingThread based on the shutdownRequested flag, so that
-       * entire task gets cancelled.
-       */
-      if (shutdownRequested.get()) {
-        waitingThread.interrupt();
-      }
-    }
-  }
-
-  // should wait until all messages are sent to AM before TezChild shutdown
-  // if this method become async in future
-  private void sendFailure(Throwable t, String message) throws IOException, TezException {
-    if (!fatalErrorSent.getAndSet(true)) {
-      task.setFatalError(t, message);
-      task.setFrameworkCounters();
-      try {
-        taskReporter.taskFailed(task.getTaskAttemptID(), t, message, null);
-      } catch (IOException e) {
-        // A failure reason already exists, Comm error just logged.
-        LOG.warn("Heartbeat failure caused by communication failure", e);
-        throw e;
-      } catch (TezException e) {
-        // A failure reason already exists, Comm error just logged.
-        LOG.warn("Heartbeat failure reported by AM", e);
-        throw e;
-      }
-    } else {
-      LOG.warn("Ignoring fatal error since another error has already been reported", t);
-    }
-  }
-
-  @Override
-  public void addEvents(Collection<TezEvent> events) {
-    if (taskRunning.get()) {
-      taskReporter.addEvents(task.getTaskAttemptID(), events);
-    }
-  }
-
-  @Override
-  public synchronized void signalFatalError(TezTaskAttemptID taskAttemptID, Throwable t,
-      String message, EventMetaData sourceInfo) {
-    // This can be called before a task throws an exception or after it.
-    // If called before a task throws an exception
-    // - ensure a heartbeat is sent with the diagnostics, and sent only once.
-    // - interrupt the waiting thread, and make it throw the reported error.
-    // If called after a task throws an exception, the waiting task has already returned, no point
-    // interrupting it.
-    // This case can be effectively ignored (log), as long as the run() method ends up throwing the
-    // exception.
-    //
-    //
-    if (!fatalErrorSent.getAndSet(true)) {
-      maybeRegisterFirstException(t);
-      try {
-        taskReporter.taskFailed(taskAttemptID, t, getTaskDiagnosticsString(t, message), sourceInfo);
-      } catch (IOException e) {
-        // HeartbeatFailed. Don't need to propagate the heartbeat exception since a task exception
-        // occurred earlier.
-        LOG.warn("Heartbeat failure caused by communication failure", e);
-      } catch (TezException e) {
-        // HeartbeatFailed. Don't need to propagate the heartbeat exception since a task exception
-        // occurred earlier.
-        LOG.warn("Heartbeat failure reported by AM", e);
-      } finally {
-        // Wake up the waiting thread so that it can return control
-        waitingThread.interrupt();
-      }
-    }
-  }
-
-  @Override
-  public boolean canCommit(TezTaskAttemptID taskAttemptID) {
-    if (taskRunning.get()) {
-      try {
-        return taskReporter.canCommit(taskAttemptID);
-      } catch (IOException e) {
-        LOG.warn("Communication failure while trying to commit", e);
-        maybeRegisterFirstException(e);
-        waitingThread.interrupt();
-        // Not informing the task since it will be interrupted.
-        // TODO: Should this be sent to the task as well, current Processors, etc do not handle
-        // interrupts very well.
-        return false;
-      }
-    } else {
-      return false;
-    }
-  }
-
-  @Override
-  public synchronized void reportError(Throwable t) {
-   if (taskRunning.get()) {
-      LOG.error("TaskReporter reported error", t);
-      maybeRegisterFirstException(t);
-      waitingThread.interrupt();
-      // A race is possible between a task succeeding, and a subsequent timed heartbeat failing.
-      // These errors can be ignored, since a task can only succeed if the synchronous taskSucceeded
-      // method does not throw an exception, in which case task success is registered with the AM.
-      // Leave this handling to the next getTask / actual task.
-    } else {
-      LOG.info("Ignoring Communication failure since task with id=" + task.getTaskAttemptID()
-          + " is already complete");
-    }
-  }
-
-  private void abortRunningTask() {
-    if (!taskRunning.get()) {
-      LOG.info("Task is not running");
-      waitingThread.interrupt();
-      return;
-    }
-
-    if (taskRunning.get()) {
-      try {
-        task.abortTask();
-      } catch (Exception e) {
-        LOG.warn("Error when aborting the task", e);
-        try {
-          sendFailure(e, "Error when aborting the task");
-        } catch (Exception ignored) {
-          // Ignored.
-        }
-      }
-    }
-    //Interrupt the relevant threads.  TaskRunner should be interrupted preferably.
-    if (isTaskRunning()) {
-      LOG.info("Interrupting taskRunner=" + taskRunner.getName());
-      taskRunner.interrupt();
-    } else {
-      LOG.info("Interrupting waitingThread=" + waitingThread.getName());
-      waitingThread.interrupt();
-    }
-  }
-
-  private boolean isTaskRunning() {
-    return (taskRunning.get() && task.isRunning());
-  }
-
-  @Override
-  public void shutdownRequested() {
-    shutdownRequested.set(true);
-    abortRunningTask();
-  }
-
-  private String getTaskDiagnosticsString(Throwable t, String message) {
-    String diagnostics;
-    if (t != null && message != null) {
-      diagnostics = "exceptionThrown=" + ExceptionUtils.getStackTrace(t) + ", errorMessage="
-          + message;
-    } else if (t == null && message == null) {
-      diagnostics = "Unknown error";
-    } else {
-      diagnostics = t != null ? "exceptionThrown=" + ExceptionUtils.getStackTrace(t)
-          : " errorMessage=" + message;
-    }
-    return diagnostics;
-  }
-
-  private synchronized void maybeRegisterFirstException(Throwable t) {
-    if (firstException == null) {
-      firstException = t;
-    }
-  }
-
-}

http://git-wip-us.apache.org/repos/asf/tez/blob/8b278ea8/tez-runtime-internals/src/test/java/org/apache/tez/runtime/task/TestTaskExecution.java
----------------------------------------------------------------------
diff --git a/tez-runtime-internals/src/test/java/org/apache/tez/runtime/task/TestTaskExecution.java b/tez-runtime-internals/src/test/java/org/apache/tez/runtime/task/TestTaskExecution.java
deleted file mode 100644
index a99416a..0000000
--- a/tez-runtime-internals/src/test/java/org/apache/tez/runtime/task/TestTaskExecution.java
+++ /dev/null
@@ -1,362 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.tez.runtime.task;
-
-import static org.apache.tez.runtime.task.TaskExecutionTestHelpers.createTaskReporter;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertNull;
-import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
-
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.UUID;
-import java.util.concurrent.Callable;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
-import java.util.concurrent.Future;
-
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.fs.FileSystem;
-import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.security.UserGroupInformation;
-import org.apache.hadoop.yarn.api.records.ApplicationId;
-import org.apache.tez.dag.api.ProcessorDescriptor;
-import org.apache.tez.dag.api.TezConfiguration;
-import org.apache.tez.dag.api.TezException;
-import org.apache.tez.dag.api.UserPayload;
-import org.apache.tez.dag.records.TezDAGID;
-import org.apache.tez.dag.records.TezTaskAttemptID;
-import org.apache.tez.dag.records.TezTaskID;
-import org.apache.tez.dag.records.TezVertexID;
-import org.apache.tez.runtime.api.impl.ExecutionContextImpl;
-import org.apache.tez.runtime.api.impl.InputSpec;
-import org.apache.tez.runtime.api.impl.OutputSpec;
-import org.apache.tez.runtime.api.impl.TaskSpec;
-import org.apache.tez.runtime.common.resources.ScalingAllocator;
-import org.apache.tez.runtime.task.TaskExecutionTestHelpers.TestProcessor;
-import org.apache.tez.runtime.task.TaskExecutionTestHelpers.TezTaskUmbilicalForTest;
-import org.junit.AfterClass;
-import org.junit.Before;
-import org.junit.Test;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import com.google.common.collect.HashMultimap;
-import com.google.common.util.concurrent.ListeningExecutorService;
-import com.google.common.util.concurrent.MoreExecutors;
-
-// Tests in this class cannot be run in parallel.
-public class TestTaskExecution {
-
-  private static final Logger LOG = LoggerFactory.getLogger(TestTaskExecution.class);
-
-
-
-  private static final Configuration defaultConf = new Configuration();
-  private static final FileSystem localFs;
-  private static final Path workDir;
-
-  private static final ExecutorService taskExecutor = Executors.newFixedThreadPool(1);
-
-  static {
-    defaultConf.set("fs.defaultFS", "file:///");
-    defaultConf.set(TezConfiguration.TEZ_TASK_SCALE_MEMORY_ALLOCATOR_CLASS,
-        ScalingAllocator.class.getName());
-    try {
-      localFs = FileSystem.getLocal(defaultConf);
-      Path wd = new Path(System.getProperty("test.build.data", "/tmp"),
-          TestTaskExecution.class.getSimpleName());
-      workDir = localFs.makeQualified(wd);
-    } catch (IOException e) {
-      throw new RuntimeException(e);
-    }
-  }
-
-  @Before
-  public void reset() {
-    TestProcessor.reset();
-  }
-
-  @AfterClass
-  public static void shutdown() {
-    taskExecutor.shutdownNow();
-  }
-
-  @Test(timeout = 5000)
-  public void testSingleSuccessfulTask() throws IOException, InterruptedException, TezException,
-      ExecutionException {
-    ListeningExecutorService executor = null;
-    try {
-      ExecutorService rawExecutor = Executors.newFixedThreadPool(1);
-      executor = MoreExecutors.listeningDecorator(rawExecutor);
-      ApplicationId appId = ApplicationId.newInstance(10000, 1);
-      TezTaskUmbilicalForTest umbilical = new TezTaskUmbilicalForTest();
-      TaskReporter taskReporter = createTaskReporter(appId, umbilical);
-
-      TezTaskRunner taskRunner = createTaskRunner(appId, umbilical, taskReporter, executor,
-          TestProcessor.CONF_EMPTY);
-      // Setup the executor
-      Future<Boolean> taskRunnerFuture = taskExecutor.submit(new TaskRunnerCallable1ForTest(taskRunner));
-      // Signal the processor to go through
-      TestProcessor.signal();
-      boolean result = taskRunnerFuture.get();
-      assertTrue(result);
-      assertNull(taskReporter.currentCallable);
-      umbilical.verifyTaskSuccessEvent();
-    } finally {
-      executor.shutdownNow();
-    }
-  }
-
-  @Test(timeout = 5000)
-  public void testMultipleSuccessfulTasks() throws IOException, InterruptedException, TezException,
-      ExecutionException {
-
-    ListeningExecutorService executor = null;
-    try {
-      ExecutorService rawExecutor = Executors.newFixedThreadPool(1);
-      executor = MoreExecutors.listeningDecorator(rawExecutor);
-      ApplicationId appId = ApplicationId.newInstance(10000, 1);
-      TezTaskUmbilicalForTest umbilical = new TezTaskUmbilicalForTest();
-      TaskReporter taskReporter = createTaskReporter(appId, umbilical);
-
-      TezTaskRunner taskRunner = createTaskRunner(appId, umbilical, taskReporter, executor,
-          TestProcessor.CONF_EMPTY);
-      // Setup the executor
-      Future<Boolean> taskRunnerFuture = taskExecutor.submit(new TaskRunnerCallable1ForTest(taskRunner));
-      // Signal the processor to go through
-      TestProcessor.signal();
-      boolean result = taskRunnerFuture.get();
-      assertTrue(result);
-      assertNull(taskReporter.currentCallable);
-      umbilical.verifyTaskSuccessEvent();
-      umbilical.resetTrackedEvents();
-
-      taskRunner = createTaskRunner(appId, umbilical, taskReporter, executor,
-          TestProcessor.CONF_EMPTY);
-      // Setup the executor
-      taskRunnerFuture = taskExecutor.submit(new TaskRunnerCallable1ForTest(taskRunner));
-      // Signal the processor to go through
-      TestProcessor.signal();
-      result = taskRunnerFuture.get();
-      assertTrue(result);
-      assertNull(taskReporter.currentCallable);
-      umbilical.verifyTaskSuccessEvent();
-    } finally {
-      executor.shutdownNow();
-    }
-  }
-
-  // test task failed due to exception in Processor
-  @Test(timeout = 5000)
-  public void testFailedTask() throws IOException, InterruptedException, TezException {
-
-    ListeningExecutorService executor = null;
-    try {
-      ExecutorService rawExecutor = Executors.newFixedThreadPool(1);
-      executor = MoreExecutors.listeningDecorator(rawExecutor);
-      ApplicationId appId = ApplicationId.newInstance(10000, 1);
-      TezTaskUmbilicalForTest umbilical = new TezTaskUmbilicalForTest();
-      TaskReporter taskReporter = createTaskReporter(appId, umbilical);
-
-      TezTaskRunner taskRunner = createTaskRunner(appId, umbilical, taskReporter, executor,
-          TestProcessor.CONF_THROW_TEZ_EXCEPTION);
-      // Setup the executor
-      Future<Boolean> taskRunnerFuture = taskExecutor.submit(new TaskRunnerCallable1ForTest(taskRunner));
-      // Signal the processor to go through
-      TestProcessor.awaitStart();
-      TestProcessor.signal();
-      try {
-        taskRunnerFuture.get();
-        fail("Expecting the task to fail");
-      } catch (ExecutionException e) {
-        Throwable cause = e.getCause();
-        LOG.info(cause.getClass().getName());
-        assertTrue(cause instanceof TezException);
-      }
-
-      assertNull(taskReporter.currentCallable);
-      umbilical.verifyTaskFailedEvent("Failure while running task:org.apache.tez.dag.api.TezException: TezException");
-    } finally {
-      executor.shutdownNow();
-    }
-  }
-
-  // Test task failed due to Processor class not found
-  @Test(timeout = 5000)
-  public void testFailedTask2() throws IOException, InterruptedException, TezException {
-
-    ListeningExecutorService executor = null;
-    try {
-      ExecutorService rawExecutor = Executors.newFixedThreadPool(1);
-      executor = MoreExecutors.listeningDecorator(rawExecutor);
-      ApplicationId appId = ApplicationId.newInstance(10000, 1);
-      TezTaskUmbilicalForTest umbilical = new TezTaskUmbilicalForTest();
-      TaskReporter taskReporter = createTaskReporter(appId, umbilical);
-
-      TezTaskRunner taskRunner = createTaskRunner(appId, umbilical, taskReporter, executor,
-          "NotExitedProcessor", TestProcessor.CONF_THROW_TEZ_EXCEPTION);
-      // Setup the executor
-      Future<Boolean> taskRunnerFuture = taskExecutor.submit(new TaskRunnerCallable1ForTest(taskRunner));
-      try {
-        taskRunnerFuture.get();
-      } catch (ExecutionException e) {
-        Throwable cause = e.getCause();
-        LOG.info(cause.getClass().getName());
-        assertTrue(cause instanceof TezException);
-      }
-      assertNull(taskReporter.currentCallable);
-      umbilical.verifyTaskFailedEvent("Failure while running task:org.apache.tez.dag.api.TezUncheckedException: "
-            + "Unable to load class: NotExitedProcessor");
-    } finally {
-      executor.shutdownNow();
-    }
-  }
-
-  @Test(timeout = 5000)
-  public void testHeartbeatException() throws IOException, InterruptedException, TezException {
-
-    ListeningExecutorService executor = null;
-    try {
-      ExecutorService rawExecutor = Executors.newFixedThreadPool(1);
-      executor = MoreExecutors.listeningDecorator(rawExecutor);
-      ApplicationId appId = ApplicationId.newInstance(10000, 1);
-      TezTaskUmbilicalForTest umbilical = new TezTaskUmbilicalForTest();
-      TaskReporter taskReporter = createTaskReporter(appId, umbilical);
-
-      TezTaskRunner taskRunner = createTaskRunner(appId, umbilical, taskReporter, executor,
-          TestProcessor.CONF_EMPTY);
-      // Setup the executor
-      Future<Boolean> taskRunnerFuture = taskExecutor.submit(new TaskRunnerCallable1ForTest(taskRunner));
-      // Signal the processor to go through
-      TestProcessor.awaitStart();
-      umbilical.signalThrowException();
-      umbilical.awaitRegisteredEvent();
-      // Not signaling an actual start to verify task interruption
-      try {
-        taskRunnerFuture.get();
-        fail("Expecting the task to fail");
-      } catch (ExecutionException e) {
-        Throwable cause = e.getCause();
-        assertTrue(cause instanceof IOException);
-        assertTrue(cause.getMessage().contains(TaskExecutionTestHelpers.HEARTBEAT_EXCEPTION_STRING));
-      }
-      TestProcessor.awaitCompletion();
-      assertTrue(TestProcessor.wasInterrupted());
-      assertNull(taskReporter.currentCallable);
-      // No completion events since umbilical communication already failed.
-      umbilical.verifyNoCompletionEvents();
-    } finally {
-      executor.shutdownNow();
-    }
-  }
-
-  @Test(timeout = 5000)
-  public void testHeartbeatShouldDie() throws IOException, InterruptedException, TezException,
-      ExecutionException {
-
-    ListeningExecutorService executor = null;
-    try {
-      ExecutorService rawExecutor = Executors.newFixedThreadPool(1);
-      executor = MoreExecutors.listeningDecorator(rawExecutor);
-      ApplicationId appId = ApplicationId.newInstance(10000, 1);
-      TezTaskUmbilicalForTest umbilical = new TezTaskUmbilicalForTest();
-      TaskReporter taskReporter = createTaskReporter(appId, umbilical);
-
-      TezTaskRunner taskRunner = createTaskRunner(appId, umbilical, taskReporter, executor,
-          TestProcessor.CONF_EMPTY);
-      // Setup the executor
-      Future<Boolean> taskRunnerFuture = taskExecutor.submit(new TaskRunnerCallable1ForTest(taskRunner));
-      // Signal the processor to go through
-      TestProcessor.awaitStart();
-      umbilical.signalSendShouldDie();
-      umbilical.awaitRegisteredEvent();
-      // Not signaling an actual start to verify task interruption
-
-      boolean result = taskRunnerFuture.get();
-      assertFalse(result);
-
-      TestProcessor.awaitCompletion();
-      assertTrue(TestProcessor.wasInterrupted());
-      assertNull(taskReporter.currentCallable);
-      // TODO Is this statement correct ?
-      // No completion events since shouldDie was requested by the AM, which should have killed the
-      // task.
-      umbilical.verifyNoCompletionEvents();
-    } finally {
-      executor.shutdownNow();
-    }
-  }
-
-  // Potential new tests
-  // Different states - initialization failure, close failure
-  // getTask states
-
-  private static class TaskRunnerCallable1ForTest implements Callable<Boolean> {
-    private final TezTaskRunner taskRunner;
-
-    public TaskRunnerCallable1ForTest(TezTaskRunner taskRunner) {
-      this.taskRunner = taskRunner;
-    }
-
-    @Override
-    public Boolean call() throws Exception {
-      return taskRunner.run();
-    }
-  }
-
-
-
-
-
-  private TezTaskRunner createTaskRunner(ApplicationId appId, TezTaskUmbilicalForTest umbilical,
-      TaskReporter taskReporter, ListeningExecutorService executor, byte[] processorConf)
-      throws IOException {
-    return createTaskRunner(appId, umbilical, taskReporter, executor, TestProcessor.class.getName(),
-        processorConf);
-  }
-
-  private TezTaskRunner createTaskRunner(ApplicationId appId, TezTaskUmbilicalForTest umbilical,
-      TaskReporter taskReporter, ListeningExecutorService executor, String processorClass, byte[] processorConf) throws IOException{
-    TezConfiguration tezConf = new TezConfiguration(defaultConf);
-    UserGroupInformation ugi = UserGroupInformation.getCurrentUser();
-    Path testDir = new Path(workDir, UUID.randomUUID().toString());
-    String[] localDirs = new String[] { testDir.toString() };
-
-    TezDAGID dagId = TezDAGID.getInstance(appId, 1);
-    TezVertexID vertexId = TezVertexID.getInstance(dagId, 1);
-    TezTaskID taskId = TezTaskID.getInstance(vertexId, 1);
-    TezTaskAttemptID taskAttemptId = TezTaskAttemptID.getInstance(taskId, 1);
-    ProcessorDescriptor processorDescriptor = ProcessorDescriptor.create(processorClass)
-        .setUserPayload(UserPayload.create(ByteBuffer.wrap(processorConf)));
-    TaskSpec taskSpec = new TaskSpec(taskAttemptId, "dagName", "vertexName", -1, processorDescriptor,
-        new ArrayList<InputSpec>(), new ArrayList<OutputSpec>(), null);
-
-    TezTaskRunner taskRunner = new TezTaskRunner(tezConf, ugi, localDirs, taskSpec, 1,
-        new HashMap<String, ByteBuffer>(), new HashMap<String, String>(), HashMultimap.<String, String> create(), taskReporter,
-        executor, null, "", new ExecutionContextImpl("localhost"), Runtime.getRuntime().maxMemory());
-    return taskRunner;
-  }
-
-
-}

http://git-wip-us.apache.org/repos/asf/tez/blob/8b278ea8/tez-runtime-internals/src/test/java/org/apache/tez/runtime/task/TestTaskExecution2.java
----------------------------------------------------------------------
diff --git a/tez-runtime-internals/src/test/java/org/apache/tez/runtime/task/TestTaskExecution2.java b/tez-runtime-internals/src/test/java/org/apache/tez/runtime/task/TestTaskExecution2.java
index 12d9d3f..ce9095a 100644
--- a/tez-runtime-internals/src/test/java/org/apache/tez/runtime/task/TestTaskExecution2.java
+++ b/tez-runtime-internals/src/test/java/org/apache/tez/runtime/task/TestTaskExecution2.java
@@ -89,7 +89,7 @@ public class TestTaskExecution2 {
     try {
       localFs = FileSystem.getLocal(defaultConf);
       Path wd = new Path(System.getProperty("test.build.data", "/tmp"),
-          TestTaskExecution.class.getSimpleName());
+          TestTaskExecution2.class.getSimpleName());
       workDir = localFs.makeQualified(wd);
     } catch (IOException e) {
       throw new RuntimeException(e);


Mime
View raw message