tez-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From bi...@apache.org
Subject [1/2] TEZ-338. Determine reduce task parallelism (bikas)
Date Tue, 20 Aug 2013 00:10:31 GMT
Updated Branches:
  refs/heads/master 67c0c17c4 -> e368ede8d


http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/e368ede8/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexScheduler.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexScheduler.java b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexScheduler.java
index a2248ca..22cd885 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexScheduler.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexScheduler.java
@@ -18,13 +18,21 @@
 
 package org.apache.tez.dag.app.dag.impl;
 
+import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Collection;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.List;
 
+import org.apache.hadoop.conf.Configuration;
+import org.apache.tez.common.TezJobConfig;
+import org.apache.tez.common.TezUtils;
 import org.apache.tez.dag.api.EdgeProperty;
 import org.apache.tez.dag.api.InputDescriptor;
 import org.apache.tez.dag.api.OutputDescriptor;
+import org.apache.tez.dag.api.ProcessorDescriptor;
+import org.apache.tez.dag.api.TezConfiguration;
 import org.apache.tez.dag.api.TezUncheckedException;
 import org.apache.tez.dag.app.dag.Task;
 import org.apache.tez.dag.app.dag.Vertex;
@@ -32,6 +40,8 @@ import org.apache.tez.dag.records.TezDAGID;
 import org.apache.tez.dag.records.TezTaskAttemptID;
 import org.apache.tez.dag.records.TezTaskID;
 import org.apache.tez.dag.records.TezVertexID;
+import org.apache.tez.engine.records.TezDependentTaskCompletionEvent;
+import org.apache.tez.mapreduce.hadoop.MRHelpers;
 import org.junit.Assert;
 import org.junit.Test;
 import org.mockito.invocation.InvocationOnMock;
@@ -42,9 +52,179 @@ import static org.mockito.Mockito.*;
 public class TestVertexScheduler {
 
   @SuppressWarnings({ "unchecked", "rawtypes" })
-  @Test
-  public void testBipartiteSlowStartVertexScheduler() {
-    BipartiteSlowStartVertexScheduler scheduler = null;
+  @Test(timeout = 5000)
+  public void testShuffleVertexManagerAutoParallelism() throws IOException {
+    Configuration conf = new Configuration();
+    conf.setBoolean(
+        TezConfiguration.TEZ_AM_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL,
+        true);
+    conf.setLong(TezConfiguration.TEZ_AM_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE, 1000L);
+    ShuffleVertexManager scheduler = null;
+    TezDAGID dagId = new TezDAGID("1", 1, 1);
+    HashMap<Vertex, EdgeProperty> mockInputVertices = 
+        new HashMap<Vertex, EdgeProperty>();
+    Vertex mockSrcVertex1 = mock(Vertex.class);
+    TezVertexID mockSrcVertexId1 = new TezVertexID(dagId, 1);
+    EdgeProperty eProp1 = new EdgeProperty(
+        EdgeProperty.ConnectionPattern.BIPARTITE,
+        EdgeProperty.SourceType.STABLE, new OutputDescriptor("out"),
+        new InputDescriptor("in"));
+    when(mockSrcVertex1.getVertexId()).thenReturn(mockSrcVertexId1);
+    Vertex mockSrcVertex2 = mock(Vertex.class);
+    TezVertexID mockSrcVertexId2 = new TezVertexID(dagId, 2);
+    EdgeProperty eProp2 = new EdgeProperty(
+        EdgeProperty.ConnectionPattern.BIPARTITE,
+        EdgeProperty.SourceType.STABLE, new OutputDescriptor("out"),
+        new InputDescriptor("in"));
+    when(mockSrcVertex2.getVertexId()).thenReturn(mockSrcVertexId2);
+    Vertex mockSrcVertex3 = mock(Vertex.class);
+    TezVertexID mockSrcVertexId3 = new TezVertexID(dagId, 3);
+    EdgeProperty eProp3 = new EdgeProperty(
+        EdgeProperty.ConnectionPattern.ONE_TO_ALL,
+        EdgeProperty.SourceType.STABLE, new OutputDescriptor("out"),
+        new InputDescriptor("in"));
+    when(mockSrcVertex3.getVertexId()).thenReturn(mockSrcVertexId3);
+    
+    Vertex mockManagedVertex = mock(Vertex.class);
+    TezVertexID mockManagedVertexId = new TezVertexID(dagId, 4);
+    when(mockManagedVertex.getVertexId()).thenReturn(mockManagedVertexId);
+    when(mockManagedVertex.getInputVertices()).thenReturn(mockInputVertices);
+    
+    TezDependentTaskCompletionEvent mockEvent = 
+        mock(TezDependentTaskCompletionEvent.class);
+    
+    mockInputVertices.put(mockSrcVertex1, eProp1);
+    mockInputVertices.put(mockSrcVertex2, eProp2);
+    mockInputVertices.put(mockSrcVertex3, eProp3);
+
+    // check initialization
+    scheduler = createScheduler(conf, mockManagedVertex, 0.1f, 0.1f);
+    Assert.assertTrue(scheduler.bipartiteSources.size() == 2);
+    Assert.assertTrue(scheduler.bipartiteSources.containsKey(mockSrcVertexId1));
+    Assert.assertTrue(scheduler.bipartiteSources.containsKey(mockSrcVertexId2));
+    
+    final HashMap<TezTaskID, Task> managedTasks = new HashMap<TezTaskID, Task>();
+    final TezTaskID mockTaskId1 = new TezTaskID(mockManagedVertexId, 0);
+    managedTasks.put(mockTaskId1, null);
+    final TezTaskID mockTaskId2 = new TezTaskID(mockManagedVertexId, 1);
+    managedTasks.put(mockTaskId2, null);
+    final TezTaskID mockTaskId3 = new TezTaskID(mockManagedVertexId, 2);
+    managedTasks.put(mockTaskId3, null);
+    final TezTaskID mockTaskId4 = new TezTaskID(mockManagedVertexId, 3);
+    managedTasks.put(mockTaskId4, null);
+    
+    when(mockManagedVertex.getTotalTasks()).thenReturn(managedTasks.size());
+    when(mockManagedVertex.getTasks()).thenReturn(managedTasks);
+    
+    final HashSet<TezTaskID> scheduledTasks = new HashSet<TezTaskID>();
+    doAnswer(new Answer() {
+      public Object answer(InvocationOnMock invocation) {
+          Object[] args = invocation.getArguments();
+          scheduledTasks.clear();
+          scheduledTasks.addAll((Collection<TezTaskID>)args[0]); 
+          return null;
+      }}).when(mockManagedVertex).scheduleTasks(anyCollection());
+    
+    final List<byte[]> taskPayloads = new ArrayList<byte[]>();
+    
+    doAnswer(new Answer() {
+      public Object answer(InvocationOnMock invocation) {
+          managedTasks.remove(mockTaskId3);
+          managedTasks.remove(mockTaskId4);
+          taskPayloads.clear();
+          taskPayloads.addAll((List<byte[]>)invocation.getArguments()[1]);
+          return null;
+      }}).when(mockManagedVertex).setParallelism(eq(2), anyList());
+    
+    // source vertices have 0 tasks. immediate start of all managed tasks
+    when(mockSrcVertex1.getTotalTasks()).thenReturn(0);
+    when(mockSrcVertex2.getTotalTasks()).thenReturn(0);
+    scheduler.onVertexStarted();
+    Assert.assertTrue(scheduler.pendingTasks.isEmpty());
+    Assert.assertTrue(scheduledTasks.size() == 4); // all tasks scheduled
+    scheduledTasks.clear();
+    
+    when(mockSrcVertex1.getTotalTasks()).thenReturn(2);
+    when(mockSrcVertex2.getTotalTasks()).thenReturn(2);
+
+    TezTaskAttemptID mockSrcAttemptId11 = 
+        new TezTaskAttemptID(new TezTaskID(mockSrcVertexId1, 0), 0);
+    TezTaskAttemptID mockSrcAttemptId12 = 
+        new TezTaskAttemptID(new TezTaskID(mockSrcVertexId1, 1), 0);
+    TezTaskAttemptID mockSrcAttemptId21 = 
+        new TezTaskAttemptID(new TezTaskID(mockSrcVertexId2, 0), 0);
+    TezTaskAttemptID mockSrcAttemptId31 = 
+        new TezTaskAttemptID(new TezTaskID(mockSrcVertexId3, 0), 0);
+
+    // parallelism not change due to large data size
+    when(mockEvent.getDataSize()).thenReturn(5000L);
+    scheduler = createScheduler(conf, mockManagedVertex, 0.1f, 0.1f);
+    scheduler.onVertexStarted();
+    Assert.assertTrue(scheduler.pendingTasks.size() == 4); // no tasks scheduled
+    Assert.assertTrue(scheduler.numSourceTasks == 4);
+    scheduler.onSourceTaskCompleted(mockSrcAttemptId11, mockEvent);
+    // managedVertex tasks reduced
+    verify(mockManagedVertex, times(0)).setParallelism(anyInt(), anyList());
+    Assert.assertEquals(0, scheduler.pendingTasks.size()); // all tasks scheduled
+    Assert.assertEquals(4, scheduledTasks.size());
+    Assert.assertEquals(1, scheduler.numSourceTasksCompleted);
+    Assert.assertEquals(5000L, scheduler.completedSourceTasksOutputSize);
+    
+    // parallelism changed due to small data size
+    when(mockEvent.getDataSize()).thenReturn(500L);
+    scheduledTasks.clear();
+    Configuration procConf = new Configuration();
+    ProcessorDescriptor procDesc = new ProcessorDescriptor("REDUCE");
+    procDesc.setUserPayload(MRHelpers.createUserPayloadFromConf(procConf));
+    when(mockManagedVertex.getProcessorDescriptor()).thenReturn(procDesc);
+    
+    scheduler = createScheduler(conf, mockManagedVertex, 0.5f, 0.5f);
+    scheduler.onVertexStarted();
+    Assert.assertEquals(4, scheduler.pendingTasks.size()); // no tasks scheduled
+    Assert.assertEquals(4, scheduler.numSourceTasks);
+    // task completion from non-bipartite stage does nothing
+    scheduler.onSourceTaskCompleted(mockSrcAttemptId31, mockEvent);
+    Assert.assertEquals(4, scheduler.pendingTasks.size()); // no tasks scheduled
+    Assert.assertEquals(4, scheduler.numSourceTasks);
+    Assert.assertEquals(0, scheduler.numSourceTasksCompleted);
+    scheduler.onSourceTaskCompleted(mockSrcAttemptId11, mockEvent);
+    Assert.assertEquals(4, scheduler.pendingTasks.size());
+    Assert.assertEquals(0, scheduledTasks.size()); // no tasks scheduled
+    Assert.assertEquals(1, scheduler.numSourceTasksCompleted);
+    Assert.assertEquals(500L, scheduler.completedSourceTasksOutputSize);
+    // ignore duplicate completion
+    scheduler.onSourceTaskCompleted(mockSrcAttemptId11, mockEvent);
+    Assert.assertEquals(4, scheduler.pendingTasks.size());
+    Assert.assertEquals(0, scheduledTasks.size()); // no tasks scheduled
+    Assert.assertEquals(1, scheduler.numSourceTasksCompleted);
+    Assert.assertEquals(500L, scheduler.completedSourceTasksOutputSize);
+    
+    scheduler.onSourceTaskCompleted(mockSrcAttemptId12, mockEvent);
+    // managedVertex tasks reduced
+    verify(mockManagedVertex).setParallelism(eq(2), anyList());
+    Assert.assertEquals(2, taskPayloads.size());
+    Assert.assertEquals(0, scheduler.pendingTasks.size()); // all tasks scheduled
+    Assert.assertEquals(2, scheduledTasks.size());
+    Assert.assertTrue(scheduledTasks.contains(mockTaskId1));
+    Assert.assertTrue(scheduledTasks.contains(mockTaskId2));
+    Assert.assertEquals(2, scheduler.numSourceTasksCompleted);
+    Assert.assertEquals(1000L, scheduler.completedSourceTasksOutputSize);
+    Configuration taskConf = TezUtils.createConfFromUserPayload(taskPayloads.get(0));
+    Assert.assertEquals(2,
+        taskConf.getInt(TezJobConfig.TEZ_ENGINE_SHUFFLE_PARTITION_RANGE, 0));
+    taskConf = TezUtils.createConfFromUserPayload(taskPayloads.get(1));
+    Assert.assertEquals(2,
+        taskConf.getInt(TezJobConfig.TEZ_ENGINE_SHUFFLE_PARTITION_RANGE, 0));
+    // more completions dont cause recalculation of parallelism
+    scheduler.onSourceTaskCompleted(mockSrcAttemptId21, mockEvent);
+    verify(mockManagedVertex).setParallelism(eq(2), anyList());
+  }
+  
+  @SuppressWarnings({ "unchecked", "rawtypes" })
+  @Test(timeout = 5000)
+  public void testShuffleVertexManagerSlowStart() {
+    Configuration conf = new Configuration();
+    ShuffleVertexManager scheduler = null;
     TezDAGID dagId = new TezDAGID("1", 1, 1);
     HashMap<Vertex, EdgeProperty> mockInputVertices = 
         new HashMap<Vertex, EdgeProperty>();
@@ -74,12 +254,14 @@ public class TestVertexScheduler {
     TezVertexID mockManagedVertexId = new TezVertexID(dagId, 3);
     when(mockManagedVertex.getVertexId()).thenReturn(mockManagedVertexId);
     when(mockManagedVertex.getInputVertices()).thenReturn(mockInputVertices);
+    
+    TezDependentTaskCompletionEvent mockEvent = 
+        mock(TezDependentTaskCompletionEvent.class);
 
     // fail if there is no bipartite src vertex
     mockInputVertices.put(mockSrcVertex3, eProp3);
     try {
-      scheduler = new BipartiteSlowStartVertexScheduler(mockManagedVertex,
-          0.1f, 0.1f);
+      scheduler = createScheduler(conf, mockManagedVertex, 0.1f, 0.1f);
      Assert.assertFalse(true);
     } catch (TezUncheckedException e) {
       Assert.assertTrue(e.getMessage().contains(
@@ -90,8 +272,7 @@ public class TestVertexScheduler {
     mockInputVertices.put(mockSrcVertex2, eProp2);
     
     // check initialization
-    scheduler = 
-        new BipartiteSlowStartVertexScheduler(mockManagedVertex, 0.1f, 0.1f);
+    scheduler = createScheduler(conf, mockManagedVertex, 0.1f, 0.1f);
     Assert.assertTrue(scheduler.bipartiteSources.size() == 2);
     Assert.assertTrue(scheduler.bipartiteSources.containsKey(mockSrcVertexId1));
     Assert.assertTrue(scheduler.bipartiteSources.containsKey(mockSrcVertexId2));
@@ -127,131 +308,158 @@ public class TestVertexScheduler {
     when(mockSrcVertex2.getTotalTasks()).thenReturn(2);
 
     try {
-      // source vertex have some tasks. min, max == 0
-      scheduler = new BipartiteSlowStartVertexScheduler(mockManagedVertex, 0, 0);
+      // source vertex have some tasks. min < 0.
+      scheduler = createScheduler(conf, mockManagedVertex, -0.1f, 0);
+      Assert.assertTrue(false); // should not come here
+    } catch (IllegalArgumentException e) {
+      Assert.assertTrue(e.getMessage().contains(
+          "Invalid values for slowStartMinSrcCompletionFraction"));
+    }
+    
+    try {
+      // source vertex have some tasks. min > max
+      scheduler = createScheduler(conf, mockManagedVertex, 0.5f, 0.3f);
       Assert.assertTrue(false); // should not come here
     } catch (IllegalArgumentException e) {
       Assert.assertTrue(e.getMessage().contains(
           "Invalid values for slowStartMinSrcCompletionFraction"));
     }
     
+    // source vertex have some tasks. min, max == 0
+    scheduler = createScheduler(conf, mockManagedVertex, 0, 0);
+    scheduler.onVertexStarted();
+    Assert.assertTrue(scheduler.numSourceTasks == 4);
+    Assert.assertTrue(scheduler.totalTasksToSchedule == 3);
+    Assert.assertTrue(scheduler.numSourceTasksCompleted == 0);
+    Assert.assertTrue(scheduler.pendingTasks.isEmpty());
+    Assert.assertTrue(scheduledTasks.size() == 3); // all tasks scheduled
+
     TezTaskAttemptID mockSrcAttemptId11 = 
         new TezTaskAttemptID(new TezTaskID(mockSrcVertexId1, 0), 0);
     TezTaskAttemptID mockSrcAttemptId12 = 
-        new TezTaskAttemptID(new TezTaskID(mockSrcVertexId1, 0), 1);
+        new TezTaskAttemptID(new TezTaskID(mockSrcVertexId1, 1), 0);
     TezTaskAttemptID mockSrcAttemptId21 = 
         new TezTaskAttemptID(new TezTaskID(mockSrcVertexId2, 0), 0);
     TezTaskAttemptID mockSrcAttemptId22 = 
-        new TezTaskAttemptID(new TezTaskID(mockSrcVertexId2, 0), 1);
+        new TezTaskAttemptID(new TezTaskID(mockSrcVertexId2, 1), 0);
     TezTaskAttemptID mockSrcAttemptId31 = 
         new TezTaskAttemptID(new TezTaskID(mockSrcVertexId3, 0), 0);
     
     // min, max > 0 and min == max
-    scheduler = 
-        new BipartiteSlowStartVertexScheduler(mockManagedVertex, 0.25f, 0.25f);
+    scheduler = createScheduler(conf, mockManagedVertex, 0.25f, 0.25f);
     scheduler.onVertexStarted();
     Assert.assertTrue(scheduler.pendingTasks.size() == 3); // no tasks scheduled
     Assert.assertTrue(scheduler.numSourceTasks == 4);
     // task completion from non-bipartite stage does nothing
-    scheduler.onSourceTaskCompleted(mockSrcAttemptId31);
+    scheduler.onSourceTaskCompleted(mockSrcAttemptId31, mockEvent);
     Assert.assertTrue(scheduler.pendingTasks.size() == 3); // no tasks scheduled
     Assert.assertTrue(scheduler.numSourceTasks == 4);
     Assert.assertTrue(scheduler.numSourceTasksCompleted == 0);
-    scheduler.onSourceTaskCompleted(mockSrcAttemptId11);
+    scheduler.onSourceTaskCompleted(mockSrcAttemptId11, mockEvent);
     Assert.assertTrue(scheduler.pendingTasks.isEmpty());
     Assert.assertTrue(scheduledTasks.size() == 3); // all tasks scheduled
     Assert.assertTrue(scheduler.numSourceTasksCompleted == 1);
     
-    // min, max > 0 and min == max
-    scheduler = 
-        new BipartiteSlowStartVertexScheduler(mockManagedVertex, 1.0f, 1.0f);
+    // min, max > 0 and min == max == absolute max 1.0
+    scheduler = createScheduler(conf, mockManagedVertex, 1.0f, 1.0f);
     scheduler.onVertexStarted();
     Assert.assertTrue(scheduler.pendingTasks.size() == 3); // no tasks scheduled
     Assert.assertTrue(scheduler.numSourceTasks == 4);
     // task completion from non-bipartite stage does nothing
-    scheduler.onSourceTaskCompleted(mockSrcAttemptId31);
+    scheduler.onSourceTaskCompleted(mockSrcAttemptId31, mockEvent);
     Assert.assertTrue(scheduler.pendingTasks.size() == 3); // no tasks scheduled
     Assert.assertTrue(scheduler.numSourceTasks == 4);
     Assert.assertTrue(scheduler.numSourceTasksCompleted == 0);
-    scheduler.onSourceTaskCompleted(mockSrcAttemptId11);
+    scheduler.onSourceTaskCompleted(mockSrcAttemptId11, mockEvent);
     Assert.assertTrue(scheduler.pendingTasks.size() == 3);
     Assert.assertTrue(scheduler.numSourceTasksCompleted == 1);
-    scheduler.onSourceTaskCompleted(mockSrcAttemptId12);
+    scheduler.onSourceTaskCompleted(mockSrcAttemptId12, mockEvent);
     Assert.assertTrue(scheduler.pendingTasks.size() == 3);
     Assert.assertTrue(scheduler.numSourceTasksCompleted == 2);
-    scheduler.onSourceTaskCompleted(mockSrcAttemptId21);
+    scheduler.onSourceTaskCompleted(mockSrcAttemptId21, mockEvent);
     Assert.assertTrue(scheduler.pendingTasks.size() == 3);
     Assert.assertTrue(scheduler.numSourceTasksCompleted == 3);
-    scheduler.onSourceTaskCompleted(mockSrcAttemptId22);
+    scheduler.onSourceTaskCompleted(mockSrcAttemptId22, mockEvent);
     Assert.assertTrue(scheduler.pendingTasks.isEmpty());
     Assert.assertTrue(scheduledTasks.size() == 3); // all tasks scheduled
     Assert.assertTrue(scheduler.numSourceTasksCompleted == 4);
     
     // min, max > 0 and min == max
-    scheduler = 
-        new BipartiteSlowStartVertexScheduler(mockManagedVertex, 1.0f, 1.0f);
+    scheduler = createScheduler(conf, mockManagedVertex, 1.0f, 1.0f);
     scheduler.onVertexStarted();
     Assert.assertTrue(scheduler.pendingTasks.size() == 3); // no tasks scheduled
     Assert.assertTrue(scheduler.numSourceTasks == 4);
     // task completion from non-bipartite stage does nothing
-    scheduler.onSourceTaskCompleted(mockSrcAttemptId31);
+    scheduler.onSourceTaskCompleted(mockSrcAttemptId31, mockEvent);
     Assert.assertTrue(scheduler.pendingTasks.size() == 3); // no tasks scheduled
     Assert.assertTrue(scheduler.numSourceTasks == 4);
     Assert.assertTrue(scheduler.numSourceTasksCompleted == 0);
-    scheduler.onSourceTaskCompleted(mockSrcAttemptId11);
+    scheduler.onSourceTaskCompleted(mockSrcAttemptId11, mockEvent);
     Assert.assertTrue(scheduler.pendingTasks.size() == 3);
     Assert.assertTrue(scheduler.numSourceTasksCompleted == 1);
-    scheduler.onSourceTaskCompleted(mockSrcAttemptId12);
+    scheduler.onSourceTaskCompleted(mockSrcAttemptId12, mockEvent);
     Assert.assertTrue(scheduler.pendingTasks.size() == 3);
     Assert.assertTrue(scheduler.numSourceTasksCompleted == 2);
-    scheduler.onSourceTaskCompleted(mockSrcAttemptId21);
+    scheduler.onSourceTaskCompleted(mockSrcAttemptId21, mockEvent);
     Assert.assertTrue(scheduler.pendingTasks.size() == 3);
     Assert.assertTrue(scheduler.numSourceTasksCompleted == 3);
-    scheduler.onSourceTaskCompleted(mockSrcAttemptId22);
+    scheduler.onSourceTaskCompleted(mockSrcAttemptId22, mockEvent);
     Assert.assertTrue(scheduler.pendingTasks.isEmpty());
     Assert.assertTrue(scheduledTasks.size() == 3); // all tasks scheduled
     Assert.assertTrue(scheduler.numSourceTasksCompleted == 4);
     
     // min, max > and min < max
-    scheduler = 
-        new BipartiteSlowStartVertexScheduler(mockManagedVertex, 0.25f, 0.75f);
+    scheduler = createScheduler(conf, mockManagedVertex, 0.25f, 0.75f);
     scheduler.onVertexStarted();
     Assert.assertTrue(scheduler.pendingTasks.size() == 3); // no tasks scheduled
     Assert.assertTrue(scheduler.numSourceTasks == 4);
-    scheduler.onSourceTaskCompleted(mockSrcAttemptId11);
-    scheduler.onSourceTaskCompleted(mockSrcAttemptId12);
+    scheduler.onSourceTaskCompleted(mockSrcAttemptId11, mockEvent);
+    scheduler.onSourceTaskCompleted(mockSrcAttemptId12, mockEvent);
+    Assert.assertTrue(scheduler.pendingTasks.size() == 2);
+    Assert.assertTrue(scheduledTasks.size() == 1); // 1 task scheduled
+    Assert.assertTrue(scheduler.numSourceTasksCompleted == 2);
+    // completion of same task again should not get counted
+    scheduler.onSourceTaskCompleted(mockSrcAttemptId12, mockEvent);
     Assert.assertTrue(scheduler.pendingTasks.size() == 2);
     Assert.assertTrue(scheduledTasks.size() == 1); // 1 task scheduled
     Assert.assertTrue(scheduler.numSourceTasksCompleted == 2);
-    scheduler.onSourceTaskCompleted(mockSrcAttemptId21);
+    scheduler.onSourceTaskCompleted(mockSrcAttemptId21, mockEvent);
     Assert.assertTrue(scheduler.pendingTasks.size() == 0);
     Assert.assertTrue(scheduledTasks.size() == 2); // 2 tasks scheduled
     Assert.assertTrue(scheduler.numSourceTasksCompleted == 3);
     scheduledTasks.clear();
-    scheduler.onSourceTaskCompleted(mockSrcAttemptId22); // we are done. no action
+    scheduler.onSourceTaskCompleted(mockSrcAttemptId22, mockEvent); // we are done. no action
     Assert.assertTrue(scheduler.pendingTasks.size() == 0);
     Assert.assertTrue(scheduledTasks.size() == 0); // no task scheduled
     Assert.assertTrue(scheduler.numSourceTasksCompleted == 4);
 
     // min, max > and min < max
-    scheduler = 
-        new BipartiteSlowStartVertexScheduler(mockManagedVertex, 0.25f, 1.0f);
+    scheduler = createScheduler(conf, mockManagedVertex, 0.25f, 1.0f);
     scheduler.onVertexStarted();
     Assert.assertTrue(scheduler.pendingTasks.size() == 3); // no tasks scheduled
     Assert.assertTrue(scheduler.numSourceTasks == 4);
-    scheduler.onSourceTaskCompleted(mockSrcAttemptId11);
-    scheduler.onSourceTaskCompleted(mockSrcAttemptId12);
+    scheduler.onSourceTaskCompleted(mockSrcAttemptId11, mockEvent);
+    scheduler.onSourceTaskCompleted(mockSrcAttemptId12, mockEvent);
     Assert.assertTrue(scheduler.pendingTasks.size() == 2);
     Assert.assertTrue(scheduledTasks.size() == 1); // 1 task scheduled
     Assert.assertTrue(scheduler.numSourceTasksCompleted == 2);
-    scheduler.onSourceTaskCompleted(mockSrcAttemptId21);
+    scheduler.onSourceTaskCompleted(mockSrcAttemptId21, mockEvent);
     Assert.assertTrue(scheduler.pendingTasks.size() == 1);
     Assert.assertTrue(scheduledTasks.size() == 1); // 1 task scheduled
     Assert.assertTrue(scheduler.numSourceTasksCompleted == 3);
-    scheduler.onSourceTaskCompleted(mockSrcAttemptId22);
+    scheduler.onSourceTaskCompleted(mockSrcAttemptId22, mockEvent);
     Assert.assertTrue(scheduler.pendingTasks.size() == 0);
     Assert.assertTrue(scheduledTasks.size() == 1); // no task scheduled
     Assert.assertTrue(scheduler.numSourceTasksCompleted == 4);
 
   }
+  
+  private ShuffleVertexManager createScheduler(Configuration conf, 
+      Vertex vertex, float min, float max) {
+    ShuffleVertexManager scheduler = new ShuffleVertexManager(vertex);
+    conf.setFloat(TezConfiguration.TEZ_AM_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION, min);
+    conf.setFloat(TezConfiguration.TEZ_AM_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION, max);
+    scheduler.initialize(conf);
+    return scheduler;
+  }
 }

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/e368ede8/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestTaskScheduler.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestTaskScheduler.java b/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestTaskScheduler.java
index a2e600b..9566993 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestTaskScheduler.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestTaskScheduler.java
@@ -96,7 +96,6 @@ public class TestTaskScheduler {
         .getDrainableAppCallback();
 
     Configuration conf = new Configuration();
-    conf.setBoolean(TezConfiguration.TEZ_AM_AGGRESSIVE_SCHEDULING, false);
     conf.setBoolean(TezConfiguration.TEZ_AM_CONTAINER_REUSE_ENABLED, false);
     int interval = 100;
     conf.setInt(TezConfiguration.TEZ_AM_RM_HEARTBEAT_INTERVAL_MS_MAX, interval);

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/e368ede8/tez-engine-api/src/main/java/org/apache/tez/engine/records/TezDependentTaskCompletionEvent.java
----------------------------------------------------------------------
diff --git a/tez-engine-api/src/main/java/org/apache/tez/engine/records/TezDependentTaskCompletionEvent.java b/tez-engine-api/src/main/java/org/apache/tez/engine/records/TezDependentTaskCompletionEvent.java
index cba752e..5fe6c9b 100644
--- a/tez-engine-api/src/main/java/org/apache/tez/engine/records/TezDependentTaskCompletionEvent.java
+++ b/tez-engine-api/src/main/java/org/apache/tez/engine/records/TezDependentTaskCompletionEvent.java
@@ -47,7 +47,9 @@ public class TezDependentTaskCompletionEvent implements Writable {
   private String taskTrackerHttp;
   private int taskRunTime; // using int since runtime is the time difference
   private TezTaskAttemptID taskAttemptId;
+  private long dataSize;
   Status status;
+  byte[] userPayload;
   // TODO TEZAM2 Get rid of the isMap field. Job specific type information can be determined from TaskAttemptId.getTaskType
 //  boolean isMap = false;
   public static final TezDependentTaskCompletionEvent[] EMPTY_ARRAY = 
@@ -62,24 +64,35 @@ public class TezDependentTaskCompletionEvent implements Writable {
    * per event for each job. 
    * @param eventId event id, event id should be unique and assigned in
    *  incrementally, starting from 0. 
-   * @param taskId task id
+   * @param taskAttemptId task id
    * @param status task's status 
    * @param taskTrackerHttp task tracker's host:port for http. 
    */
   public TezDependentTaskCompletionEvent(int eventId, 
-                             TezTaskAttemptID taskId,
+                             TezTaskAttemptID taskAttemptId,
 //                             boolean isMap,
                              Status status, 
                              String taskTrackerHttp,
-                             int runTime){
+                             int runTime,
+                             long dataSize){
       
-    this.taskAttemptId = taskId;
+    this.taskAttemptId = taskAttemptId;
 //    this.isMap = isMap;
     this.eventId = eventId; 
     this.status =status; 
     this.taskTrackerHttp = taskTrackerHttp;
     this.taskRunTime = runTime;
+    this.dataSize = dataSize;
   }
+  
+  public TezDependentTaskCompletionEvent clone() {
+    TezDependentTaskCompletionEvent clone = new TezDependentTaskCompletionEvent(
+        this.eventId, this.taskAttemptId, this.status, this.taskTrackerHttp,
+        this.taskRunTime, this.dataSize);
+    
+    return clone;
+  }
+  
   /**
    * Returns event Id. 
    * @return event id
@@ -117,6 +130,20 @@ public class TezDependentTaskCompletionEvent implements Writable {
   public int getTaskRunTime() {
     return taskRunTime;
   }
+  
+  /**
+   * Return size of output produced by the task
+   */
+  public long getDataSize() {
+    return dataSize;
+  }
+  
+  /**
+   * @return user payload. Maybe null
+   */
+  public byte[] getUserPayload() {
+    return userPayload;
+  }
 
   /**
    * Set the task completion time
@@ -157,6 +184,14 @@ public class TezDependentTaskCompletionEvent implements Writable {
   public void setTaskTrackerHttp(String taskTrackerHttp) {
     this.taskTrackerHttp = taskTrackerHttp;
   }
+  
+  /**
+   * Set the user payload
+   * @param userPayload
+   */
+  public void setUserPayload(byte[] userPayload) {
+    this.userPayload = userPayload;
+  }
     
   @Override
   public String toString(){
@@ -170,6 +205,7 @@ public class TezDependentTaskCompletionEvent implements Writable {
     
   @Override
   public boolean equals(Object o) {
+    // not counting userPayload as that is a piggyback mechanism
     if(o == null)
       return false;
     if(o.getClass().equals(this.getClass())) {
@@ -178,6 +214,7 @@ public class TezDependentTaskCompletionEvent implements Writable {
              && this.status.equals(event.getStatus())
              && this.taskAttemptId.equals(event.getTaskAttemptID()) 
              && this.taskRunTime == event.getTaskRunTime()
+             && this.dataSize == event.getDataSize()
              && this.taskTrackerHttp.equals(event.getTaskTrackerHttp());
     }
     return false;
@@ -196,6 +233,7 @@ public class TezDependentTaskCompletionEvent implements Writable {
     WritableUtils.writeString(out, taskTrackerHttp);
     WritableUtils.writeVInt(out, taskRunTime);
     WritableUtils.writeVInt(out, eventId);
+    WritableUtils.writeCompressedByteArray(out, userPayload);
   }
 
   @Override
@@ -206,7 +244,7 @@ public class TezDependentTaskCompletionEvent implements Writable {
     taskTrackerHttp = WritableUtils.readString(in);
     taskRunTime = WritableUtils.readVInt(in);
     eventId = WritableUtils.readVInt(in);
-    
+    userPayload = WritableUtils.readCompressedByteArray(in);
   }
   
 }

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/e368ede8/tez-engine/src/main/java/org/apache/tez/engine/common/shuffle/impl/EventFetcher.java
----------------------------------------------------------------------
diff --git a/tez-engine/src/main/java/org/apache/tez/engine/common/shuffle/impl/EventFetcher.java b/tez-engine/src/main/java/org/apache/tez/engine/common/shuffle/impl/EventFetcher.java
index e5a1c83..51e05af 100644
--- a/tez-engine/src/main/java/org/apache/tez/engine/common/shuffle/impl/EventFetcher.java
+++ b/tez-engine/src/main/java/org/apache/tez/engine/common/shuffle/impl/EventFetcher.java
@@ -38,7 +38,7 @@ class EventFetcher extends Thread {
   private final ShuffleScheduler scheduler;
   private int fromEventIdx = 0;
   private int maxEventsToFetch;
-  private ExceptionReporter exceptionReporter = null;
+  private Shuffle shuffle = null;
   
   private int maxMapRuntime = 0;
 
@@ -47,14 +47,14 @@ class EventFetcher extends Thread {
   public EventFetcher(TezTaskAttemptID reduce,
                       Master umbilical,
                       ShuffleScheduler scheduler,
-                      ExceptionReporter reporter,
+                      Shuffle shuffle,
                       int maxEventsToFetch) {
     setName("EventFetcher for fetching Map Completion Events");
     setDaemon(true);    
     this.reduce = reduce;
     this.umbilical = umbilical;
     this.scheduler = scheduler;
-    exceptionReporter = reporter;
+    this.shuffle = shuffle;
     this.maxEventsToFetch = maxEventsToFetch;
   }
 
@@ -93,7 +93,7 @@ class EventFetcher extends Thread {
     } catch (InterruptedException e) {
       return;
     } catch (Throwable t) {
-      exceptionReporter.reportException(t);
+      shuffle.reportException(t);
       return;
     }
   }
@@ -146,12 +146,13 @@ class EventFetcher extends Thread {
       // 3. Remove TIPFAILED maps from neededOutputs since we don't need their
       //    outputs at all.
       for (TezDependentTaskCompletionEvent event : events) {
+        byte[] userPayload = event.getUserPayload();
+        if(userPayload != null) {
+          shuffle.updateUserPayload(userPayload);
+        }
         switch (event.getStatus()) {
         case SUCCEEDED:
-          URI u = getBaseURI(event.getTaskTrackerHttp());
-          scheduler.addKnownMapOutput(u.getHost() + ":" + u.getPort(),
-              u.toString(),
-              event.getTaskAttemptID());
+          addMapHosts(event);
           numNewMaps ++;
           int duration = event.getTaskRunTime();
           if (duration > maxMapRuntime) {
@@ -178,7 +179,19 @@ class EventFetcher extends Thread {
     return numNewMaps;
   }
   
-  private URI getBaseURI(String url) {
+  private void addMapHosts(TezDependentTaskCompletionEvent event) {
+    int reduceRange = shuffle.getReduceRange();
+    for(int i=0; i<reduceRange; ++i) {
+      int partitionId = reduce.getTaskID().getId()+i;
+      URI u = getBaseURI(event.getTaskTrackerHttp(), partitionId);
+      scheduler.addKnownMapOutput(u.getHost() + ":" + u.getPort(),
+          partitionId,
+          u.toString(),
+          event.getTaskAttemptID());
+    }
+  }
+  
+  private URI getBaseURI(String url, int reduceId) {
     StringBuffer baseUrl = new StringBuffer(url);
     if (!url.endsWith("/")) {
       baseUrl.append("/");
@@ -191,7 +204,7 @@ class EventFetcher extends Thread {
 
     baseUrl.append(jobID);
     baseUrl.append("&reduce=");
-    baseUrl.append(reduce.getTaskID().getId());
+    baseUrl.append(reduceId);
     baseUrl.append("&map=");
     URI u = URI.create(baseUrl.toString());
     return u;

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/e368ede8/tez-engine/src/main/java/org/apache/tez/engine/common/shuffle/impl/Fetcher.java
----------------------------------------------------------------------
diff --git a/tez-engine/src/main/java/org/apache/tez/engine/common/shuffle/impl/Fetcher.java b/tez-engine/src/main/java/org/apache/tez/engine/common/shuffle/impl/Fetcher.java
index fe92dd1..0acceaf 100644
--- a/tez-engine/src/main/java/org/apache/tez/engine/common/shuffle/impl/Fetcher.java
+++ b/tez-engine/src/main/java/org/apache/tez/engine/common/shuffle/impl/Fetcher.java
@@ -78,10 +78,9 @@ class Fetcher extends Thread {
   private final MergeManager merger;
   private final ShuffleScheduler scheduler;
   private final ShuffleClientMetrics metrics;
-  private final ExceptionReporter exceptionReporter;
+  private final Shuffle shuffle;
   private final int id;
   private static int nextId = 0;
-  private final int reduce;
   
   private final int connectionTimeout;
   private final int readTimeout;
@@ -98,18 +97,17 @@ class Fetcher extends Thread {
   private static boolean sslShuffle;
   private static SSLFactory sslFactory;
 
-  public Fetcher(Configuration job, TezTaskAttemptID reduceId, 
+  public Fetcher(Configuration job, 
       ShuffleScheduler scheduler, MergeManager merger,
       TezTaskReporter reporter, ShuffleClientMetrics metrics,
-      ExceptionReporter exceptionReporter, SecretKey jobTokenSecret) {
+      Shuffle shuffle, SecretKey jobTokenSecret) {
     this.job = job;
     this.reporter = reporter;
     this.scheduler = scheduler;
     this.merger = merger;
     this.metrics = metrics;
-    this.exceptionReporter = exceptionReporter;
+    this.shuffle = shuffle;
     this.id = ++nextId;
-    this.reduce = reduceId.getTaskID().getId();
     this.jobTokenSecret = jobTokenSecret;
     ioErrs = reporter.getCounter(SHUFFLE_ERR_GRP_NAME,
         ShuffleErrors.IO_ERROR.toString());
@@ -182,7 +180,7 @@ class Fetcher extends Thread {
     } catch (InterruptedException ie) {
       return;
     } catch (Throwable t) {
-      exceptionReporter.reportException(t);
+      shuffle.reportException(t);
     }
   }
 
@@ -459,7 +457,9 @@ class Fetcher extends Thread {
       return false;
     }
     
-    if (forReduce != reduce) {
+    int reduceStartId = shuffle.getReduceStartId();
+    int reduceRange = shuffle.getReduceRange();
+    if (forReduce < reduceStartId || forReduce >= reduceStartId+reduceRange) {
       wrongReduceErrs.increment(1);
       LOG.warn(getName() + " data for the wrong reduce map: " +
                mapId + " len: " + compressedLength + " decomp len: " +

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/e368ede8/tez-engine/src/main/java/org/apache/tez/engine/common/shuffle/impl/MapHost.java
----------------------------------------------------------------------
diff --git a/tez-engine/src/main/java/org/apache/tez/engine/common/shuffle/impl/MapHost.java b/tez-engine/src/main/java/org/apache/tez/engine/common/shuffle/impl/MapHost.java
index 544f3b5..24f7635 100644
--- a/tez-engine/src/main/java/org/apache/tez/engine/common/shuffle/impl/MapHost.java
+++ b/tez-engine/src/main/java/org/apache/tez/engine/common/shuffle/impl/MapHost.java
@@ -34,14 +34,30 @@ class MapHost {
   
   private State state = State.IDLE;
   private final String hostName;
+  private final int partitionId;
   private final String baseUrl;
+  private final String identifier;
   private List<TezTaskAttemptID> maps = new ArrayList<TezTaskAttemptID>();
   
-  public MapHost(String hostName, String baseUrl) {
+  public MapHost(int partitionId, String hostName, String baseUrl) {
+    this.partitionId = partitionId;
     this.hostName = hostName;
     this.baseUrl = baseUrl;
+    this.identifier = createIdentifier(hostName, partitionId);
   }
   
+  public static String createIdentifier(String hostName, int partitionId) {
+    return hostName + ":" + Integer.toString(partitionId);
+  }
+  
+  public String getIdentifier() {
+    return identifier;
+  }
+  
+  public int getPartitionId() {
+    return partitionId;
+  }
+
   public State getState() {
     return state;
   }

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/e368ede8/tez-engine/src/main/java/org/apache/tez/engine/common/shuffle/impl/Shuffle.java
----------------------------------------------------------------------
diff --git a/tez-engine/src/main/java/org/apache/tez/engine/common/shuffle/impl/Shuffle.java b/tez-engine/src/main/java/org/apache/tez/engine/common/shuffle/impl/Shuffle.java
index 6fc5226..69dd036 100644
--- a/tez-engine/src/main/java/org/apache/tez/engine/common/shuffle/impl/Shuffle.java
+++ b/tez-engine/src/main/java/org/apache/tez/engine/common/shuffle/impl/Shuffle.java
@@ -18,6 +18,7 @@
 package org.apache.tez.engine.common.shuffle.impl;
 
 import java.io.IOException;
+import java.util.concurrent.atomic.AtomicInteger;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
@@ -32,8 +33,10 @@ import org.apache.tez.common.TezEngineTaskContext;
 import org.apache.tez.common.TezJobConfig;
 import org.apache.tez.common.TezTaskReporter;
 import org.apache.tez.common.TezTaskStatus;
+import org.apache.tez.common.TezUtils;
 import org.apache.tez.common.counters.TezCounter;
 import org.apache.tez.common.counters.TaskCounter;
+import org.apache.tez.dag.api.TezUncheckedException;
 import org.apache.tez.engine.api.Processor;
 import org.apache.tez.engine.common.sort.impl.TezRawKeyValueIterator;
 
@@ -60,6 +63,9 @@ public class Shuffle implements ExceptionReporter {
   private final Progress copyPhase;
   private final Progress mergePhase;
   private final int tasksInDegree;
+  private final AtomicInteger reduceStartId;
+  private AtomicInteger reduceRange = new AtomicInteger(
+      TezJobConfig.TEZ_ENGINE_SHUFFLE_PARTITION_RANGE_DEFAULT);
   
   public Shuffle(TezEngineTaskContext taskContext,
                  RunningTaskContext runningTaskContext,
@@ -99,8 +105,14 @@ public class Shuffle implements ExceptionReporter {
     TezCounter mergedMapOutputsCounter =
         reporter.getCounter(TaskCounter.MERGED_MAP_OUTPUTS);
     
+    reduceStartId = new AtomicInteger( 
+        taskContext.getTaskAttemptId().getTaskID().getId()); 
+    LOG.info("Shuffle assigned reduce start id: " + reduceStartId.get()
+        + " with default reduce range: " + reduceRange.get());
+
     scheduler = 
-      new ShuffleScheduler(this.conf, tasksInDegree, runningTaskContext.getStatus(), 
+      new ShuffleScheduler(this.conf, tasksInDegree,
+                                runningTaskContext.getStatus(), 
                                 this, copyPhase, 
                                 shuffledMapsCounter, 
                                 reduceShuffleBytes, 
@@ -136,10 +148,9 @@ public class Shuffle implements ExceptionReporter {
             TezJobConfig.DEFAULT_TEZ_ENGINE_SHUFFLE_PARALLEL_COPIES);
     Fetcher[] fetchers = new Fetcher[numFetchers];
     for (int i=0; i < numFetchers; ++i) {
-      fetchers[i] = new Fetcher(conf, taskContext.getTaskAttemptId(), 
-                                     scheduler, merger, 
-                                     reporter, metrics, this, 
-                                     runningTaskContext.getJobTokenSecret());
+      fetchers[i] = new Fetcher(conf, scheduler,
+          merger, reporter, metrics, this,
+          runningTaskContext.getJobTokenSecret());
       fetchers[i].start();
     }
     
@@ -190,7 +201,15 @@ public class Shuffle implements ExceptionReporter {
     
     return kvIter;
   }
-
+  
+  public int getReduceStartId() {
+    return reduceStartId.get();
+  }
+  
+  public int getReduceRange() {
+    return reduceRange.get();
+  }
+  
   public synchronized void reportException(Throwable t) {
     if (throwable == null) {
       throwable = t;
@@ -210,4 +229,31 @@ public class Shuffle implements ExceptionReporter {
       super(msg, t);
     }
   }
+  
+  public void updateUserPayload(byte[] userPayload) throws IOException {
+    if(userPayload == null) {
+      return;
+    }
+    Configuration conf = TezUtils.createConfFromUserPayload(userPayload);
+    int reduceRange = conf.getInt(
+        TezJobConfig.TEZ_ENGINE_SHUFFLE_PARTITION_RANGE,
+        TezJobConfig.TEZ_ENGINE_SHUFFLE_PARTITION_RANGE_DEFAULT);
+    setReduceRange(reduceRange);
+  }
+  
+  private void setReduceRange(int range) {
+    if (range == reduceRange.get()) {
+      return;
+    }
+    if (reduceRange.compareAndSet(
+        TezJobConfig.TEZ_ENGINE_SHUFFLE_PARTITION_RANGE_DEFAULT, range)) {
+      LOG.info("Reduce range set to: " + range);
+    } else {
+      TezUncheckedException e = 
+          new TezUncheckedException("Reduce range can be set only once.");
+      reportException(e);
+      throw e; 
+    }
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/e368ede8/tez-engine/src/main/java/org/apache/tez/engine/common/shuffle/impl/ShuffleScheduler.java
----------------------------------------------------------------------
diff --git a/tez-engine/src/main/java/org/apache/tez/engine/common/shuffle/impl/ShuffleScheduler.java b/tez-engine/src/main/java/org/apache/tez/engine/common/shuffle/impl/ShuffleScheduler.java
index 1e9d358..6bd18ef 100644
--- a/tez-engine/src/main/java/org/apache/tez/engine/common/shuffle/impl/ShuffleScheduler.java
+++ b/tez-engine/src/main/java/org/apache/tez/engine/common/shuffle/impl/ShuffleScheduler.java
@@ -31,6 +31,7 @@ import java.util.concurrent.DelayQueue;
 import java.util.concurrent.Delayed;
 import java.util.concurrent.TimeUnit;
 
+import org.apache.commons.lang.mutable.MutableInt;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.conf.Configuration;
@@ -54,7 +55,7 @@ class ShuffleScheduler {
   private static final long INITIAL_PENALTY = 10000;
   private static final float PENALTY_GROWTH_RATE = 1.3f;
   
-  private final Map<TezTaskID, Boolean> finishedMaps;
+  private final Map<TezTaskID, MutableInt> finishedMaps;
   private final int tasksInDegree;
   private int remainingMaps;
   private Map<String, MapHost> mapLocations = new HashMap<String, MapHost>();
@@ -69,7 +70,7 @@ class ShuffleScheduler {
   private final Map<String,IntWritable> hostFailures = 
     new HashMap<String,IntWritable>();
   private final TezTaskStatus status;
-  private final ExceptionReporter reporter;
+  private final Shuffle shuffle;
   private final int abortFailureLimit;
   private final Progress progress;
   private final TezCounter shuffledMapsCounter;
@@ -91,7 +92,7 @@ class ShuffleScheduler {
   public ShuffleScheduler(Configuration conf,
                           int tasksInDegree,
                           TezTaskStatus status,
-                          ExceptionReporter reporter,
+                          Shuffle shuffle,
                           Progress progress,
                           TezCounter shuffledMapsCounter,
                           TezCounter reduceShuffleBytes,
@@ -99,8 +100,8 @@ class ShuffleScheduler {
     this.tasksInDegree = tasksInDegree;
     abortFailureLimit = Math.max(30, tasksInDegree / 10);
     remainingMaps = tasksInDegree;
-    finishedMaps = new HashMap<TezTaskID, Boolean>(remainingMaps);
-    this.reporter = reporter;
+    finishedMaps = new HashMap<TezTaskID, MutableInt>(remainingMaps);
+    this.shuffle = shuffle;
     this.status = status;
     this.progress = progress;
     this.shuffledMapsCounter = shuffledMapsCounter;
@@ -133,10 +134,11 @@ class ShuffleScheduler {
     
     if (!isFinishedTaskTrue(taskId)) {
       output.commit();
-      setFinishedTaskTrue(taskId);
-      shuffledMapsCounter.increment(1);
-      if (--remainingMaps == 0) {
-        notifyAll();
+      if(incrementTaskCopyAndCheckCompletion(taskId)) {
+        shuffledMapsCounter.increment(1);
+        if (--remainingMaps == 0) {
+          notifyAll();
+        }
       }
 
       // update the status
@@ -184,7 +186,7 @@ class ShuffleScheduler {
       try {
         throw new IOException(failures + " failures downloading " + mapId);
       } catch (IOException ie) {
-        reporter.reportException(ie);
+        shuffle.reportException(ie);
       }
     }
     
@@ -256,7 +258,7 @@ class ShuffleScheduler {
       LOG.fatal("Shuffle failed with too many fetch failures " +
       "and insufficient progress!");
       String errorMsg = "Exceeded MAX_FAILED_UNIQUE_FETCHES; bailing-out.";
-      reporter.reportException(new IOException(errorMsg));
+      shuffle.reportException(new IOException(errorMsg));
     }
 
   }
@@ -271,13 +273,16 @@ class ShuffleScheduler {
     }
   }
   
-  public synchronized void addKnownMapOutput(String hostName, 
+  public synchronized void addKnownMapOutput(String hostName,
+                                             int partitionId,
                                              String hostUrl,
                                              TezTaskAttemptID mapId) {
-    MapHost host = mapLocations.get(hostName);
+    String identifier = MapHost.createIdentifier(hostName, partitionId);
+    MapHost host = mapLocations.get(identifier);
     if (host == null) {
-      host = new MapHost(hostName, hostUrl);
-      mapLocations.put(hostName, host);
+      host = new MapHost(partitionId, hostName, hostUrl);
+      assert identifier.equals(host.getIdentifier());
+      mapLocations.put(identifier, host);
     }
     host.addKnownMap(mapId);
 
@@ -427,7 +432,7 @@ class ShuffleScheduler {
       } catch (InterruptedException ie) {
         return;
       } catch (Throwable t) {
-        reporter.reportException(t);
+        shuffle.reportException(t);
       }
     }
   }
@@ -444,15 +449,33 @@ class ShuffleScheduler {
   }
   
   void setFinishedTaskTrue(TezTaskID taskId) {
-    finishedMaps.put(taskId, true);
+    synchronized(finishedMaps) {
+      finishedMaps.put(taskId, new MutableInt(shuffle.getReduceRange()));
+    }
+  }
+  
+  boolean incrementTaskCopyAndCheckCompletion(TezTaskID mapTaskId) {
+    synchronized(finishedMaps) {
+      MutableInt result = finishedMaps.get(mapTaskId);
+      if(result == null) {
+        result = new MutableInt(0);
+        finishedMaps.put(mapTaskId, result);
+      }
+      result.increment();
+      return isFinishedTaskTrue(mapTaskId);
+    }
   }
   
   boolean isFinishedTaskTrue(TezTaskID taskId) {
-    Boolean result = finishedMaps.get(taskId);
-    if(result == null) {
-      return false;
+    synchronized (finishedMaps) {
+      MutableInt result = finishedMaps.get(taskId);
+      if(result == null) {
+        return false;
+      }
+      if (result.intValue() == shuffle.getReduceRange()) {
+        return true;
+      }
+      return false;      
     }
-    
-    return result.booleanValue();
   }
 }

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/e368ede8/tez-mapreduce-examples/src/main/java/org/apache/tez/mapreduce/examples/OrderedWordCount.java
----------------------------------------------------------------------
diff --git a/tez-mapreduce-examples/src/main/java/org/apache/tez/mapreduce/examples/OrderedWordCount.java b/tez-mapreduce-examples/src/main/java/org/apache/tez/mapreduce/examples/OrderedWordCount.java
index 78eca75..3edb73a 100644
--- a/tez-mapreduce-examples/src/main/java/org/apache/tez/mapreduce/examples/OrderedWordCount.java
+++ b/tez-mapreduce-examples/src/main/java/org/apache/tez/mapreduce/examples/OrderedWordCount.java
@@ -36,8 +36,6 @@ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
 import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
 import org.apache.hadoop.util.GenericOptionsParser;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
-import org.apache.hadoop.yarn.client.api.YarnClient;
-import org.apache.hadoop.yarn.client.api.impl.YarnClientImpl;
 import org.apache.tez.client.TezClient;
 import org.apache.tez.dag.api.TezConfiguration;
 import org.apache.tez.dag.api.TezException;
@@ -116,15 +114,15 @@ public class OrderedWordCount {
 
     // Set reducer class for intermediate reduce
     conf.setClass(MultiStageMRConfigUtil.getPropertyNameForIntermediateStage(1,
-        "mapreduce.job.reduce.class"), IntSumReducer.class, Reducer.class);
+        MRJobConfig.REDUCE_CLASS_ATTR), IntSumReducer.class, Reducer.class);
     // Set reducer output key class
     conf.setClass(MultiStageMRConfigUtil.getPropertyNameForIntermediateStage(1,
-        "mapreduce.map.output.key.class"), IntWritable.class, Object.class);
+        MRJobConfig.MAP_OUTPUT_KEY_CLASS), IntWritable.class, Object.class);
     // Set reducer output value class
     conf.setClass(MultiStageMRConfigUtil.getPropertyNameForIntermediateStage(1,
-        "mapreduce.map.output.value.class"), Text.class, Object.class);
+        MRJobConfig.MAP_OUTPUT_VALUE_CLASS), Text.class, Object.class);
     conf.setInt(MultiStageMRConfigUtil.getPropertyNameForIntermediateStage(1,
-        "mapreduce.job.reduces"), 2);
+        MRJobConfig.NUM_REDUCES), 2);
 
     @SuppressWarnings("deprecation")
     Job job = new Job(conf, "orderedwordcount");
@@ -143,10 +141,6 @@ public class OrderedWordCount {
     FileInputFormat.addInputPath(job, new Path(otherArgs[0]));
     FileOutputFormat.setOutputPath(job, new Path(otherArgs[1]));
 
-    YarnClient yarnClient = new YarnClientImpl();
-    yarnClient.init(conf);
-    yarnClient.start();
-
     TezClient tezClient = new TezClient(new TezConfiguration(conf));
 
     job.submit();

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/e368ede8/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/hadoop/MRHelpers.java
----------------------------------------------------------------------
diff --git a/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/hadoop/MRHelpers.java b/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/hadoop/MRHelpers.java
index 9b9c6b5..759e173 100644
--- a/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/hadoop/MRHelpers.java
+++ b/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/hadoop/MRHelpers.java
@@ -37,8 +37,6 @@ import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.FileStatus;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.io.DataInputBuffer;
-import org.apache.hadoop.io.DataOutputBuffer;
 import org.apache.hadoop.mapred.JobConf;
 import org.apache.hadoop.mapreduce.InputFormat;
 import org.apache.hadoop.mapreduce.Job;
@@ -57,10 +55,10 @@ import org.apache.hadoop.yarn.api.records.Resource;
 import org.apache.hadoop.yarn.conf.YarnConfiguration;
 import org.apache.hadoop.yarn.util.Apps;
 import org.apache.hadoop.yarn.util.ConverterUtils;
+import org.apache.tez.common.TezUtils;
 import org.apache.tez.dag.api.TezConfiguration;
 import org.apache.tez.dag.api.VertexLocationHint.TaskLocationHint;
 
-import com.google.common.base.Preconditions;
 
 public class MRHelpers {
 
@@ -439,23 +437,14 @@ public class MRHelpers {
   @Unstable
   public static byte[] createUserPayloadFromConf(Configuration conf)
       throws IOException {
-    Preconditions.checkNotNull(conf, "Configuration must be specified");
-    DataOutputBuffer dob = new DataOutputBuffer();
-    conf.write(dob);
-    return dob.getData();
+    return TezUtils.createUserPayloadFromConf(conf);
   }
 
   @LimitedPrivate("Hive, Pig")
   @Unstable
   public static Configuration createConfFromUserPayload(byte[] bb)
       throws IOException {
-    // TODO Avoid copy ?
-    Preconditions.checkNotNull(bb, "Bytes must be specified");
-    DataInputBuffer dib = new DataInputBuffer();
-    dib.reset(bb, 0, bb.length);
-    Configuration conf = new Configuration(false);
-    conf.readFields(dib);
-    return conf;
+    return TezUtils.createConfFromUserPayload(bb);
   }
 
   /**


Mime
View raw message