tez-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From rbalamo...@apache.org
Subject git commit: TEZ-1649. ShuffleVertexManager auto reduce parallelism can cause jobs to hang indefinitely (Rajesh Balamohan)
Date Tue, 14 Oct 2014 08:33:53 GMT
Repository: tez
Updated Branches:
  refs/heads/master 4b5b20e23 -> 0ad4a4b88


TEZ-1649. ShuffleVertexManager auto reduce parallelism can cause jobs to hang indefinitely  (Rajesh Balamohan)


Project: http://git-wip-us.apache.org/repos/asf/tez/repo
Commit: http://git-wip-us.apache.org/repos/asf/tez/commit/0ad4a4b8
Tree: http://git-wip-us.apache.org/repos/asf/tez/tree/0ad4a4b8
Diff: http://git-wip-us.apache.org/repos/asf/tez/diff/0ad4a4b8

Branch: refs/heads/master
Commit: 0ad4a4b889028eb1a31a719b22c33d1feb1721af
Parents: 4b5b20e
Author: Rajesh Balamohan <rbalamohan@apache.org>
Authored: Tue Oct 14 13:44:37 2014 +0530
Committer: Rajesh Balamohan <rbalamohan@apache.org>
Committed: Tue Oct 14 13:53:37 2014 +0530

----------------------------------------------------------------------
 CHANGES.txt                                     |   1 +
 .../vertexmanager/ShuffleVertexManager.java     | 177 ++++---
 .../vertexmanager/TestShuffleVertexManager.java | 500 ++++++++++++++++---
 3 files changed, 545 insertions(+), 133 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/tez/blob/0ad4a4b8/CHANGES.txt
----------------------------------------------------------------------
diff --git a/CHANGES.txt b/CHANGES.txt
index d8cc2f2..7b6fbf7 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -22,6 +22,7 @@ ALL CHANGES:
   TEZ-1646. Add support for augmenting classpath via configs.
   TEZ-1647. Issue with caching of events in VertexManager::onRootVertexInitialized.
   TEZ-1470. Recovery fails due to TaskAttemptFinishedEvent being recorded multiple times for the same task.
+  TEZ-1649. ShuffleVertexManager auto reduce parallelism can cause jobs to hang indefinitely.
 
 Release 0.5.1: 2014-10-02
 

http://git-wip-us.apache.org/repos/asf/tez/blob/0ad4a4b8/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManager.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManager.java b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManager.java
index eeb3676..9f3d3f9 100644
--- a/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManager.java
+++ b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManager.java
@@ -20,19 +20,12 @@ package org.apache.tez.dag.library.vertexmanager;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
+import com.google.common.base.Predicate;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
 import com.google.protobuf.ByteString;
-
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-
-import javax.annotation.Nullable;
-
+import com.google.protobuf.InvalidProtocolBufferException;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.classification.InterfaceAudience.Public;
@@ -43,12 +36,12 @@ import org.apache.tez.dag.api.EdgeManagerPlugin;
 import org.apache.tez.dag.api.EdgeManagerPluginContext;
 import org.apache.tez.dag.api.EdgeManagerPluginDescriptor;
 import org.apache.tez.dag.api.EdgeProperty;
+import org.apache.tez.dag.api.EdgeProperty.DataMovementType;
 import org.apache.tez.dag.api.InputDescriptor;
 import org.apache.tez.dag.api.TezUncheckedException;
 import org.apache.tez.dag.api.UserPayload;
 import org.apache.tez.dag.api.VertexManagerPlugin;
 import org.apache.tez.dag.api.VertexManagerPluginContext;
-import org.apache.tez.dag.api.EdgeProperty.DataMovementType;
 import org.apache.tez.dag.api.VertexManagerPluginContext.TaskWithLocationHint;
 import org.apache.tez.dag.api.VertexManagerPluginDescriptor;
 import org.apache.tez.runtime.api.Event;
@@ -58,9 +51,14 @@ import org.apache.tez.runtime.api.events.VertexManagerEvent;
 import org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads.ShuffleEdgeManagerConfigPayloadProto;
 import org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads.VertexManagerEventPayloadProto;
 
-import com.google.common.collect.Lists;
-import com.google.common.collect.Maps;
-import com.google.protobuf.InvalidProtocolBufferException;
+import javax.annotation.Nullable;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.BitSet;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
 
 /**
  * Starts scheduling tasks when number of completed source tasks crosses 
@@ -124,15 +122,32 @@ public class ShuffleVertexManager extends VertexManagerPlugin {
   boolean enableAutoParallelism = false;
   boolean parallelismDetermined = false;
 
-  int totalNumSourceTasks = 0;
-  int numSourceTasksCompleted = 0;
+  int totalNumBipartiteSourceTasks = 0;
+  int numBipartiteSourceTasksCompleted = 0;
   int numVertexManagerEventsReceived = 0;
   List<Integer> pendingTasks;
   int totalTasksToSchedule = 0;
   
-  Map<String, Set<Integer>> bipartiteSources = Maps.newHashMap();
+  //Track source vertex and its finished tasks
+  private final Map<String, SourceVertexInfo> srcVertexInfo = Maps.newHashMap();
+  boolean sourceVerticesScheduled = false;
+  @VisibleForTesting
+  int bipartiteSources = 0;
   long completedSourceTasksOutputSize = 0;
 
+  class SourceVertexInfo {
+    EdgeProperty edgeProperty;
+    int numFinishedTasks;
+    BitSet finishedTaskSet;
+
+    SourceVertexInfo(EdgeProperty edgeProperty) {
+      this.edgeProperty = edgeProperty;
+      if (edgeProperty.getDataMovementType() == DataMovementType.SCATTER_GATHER) {
+        finishedTaskSet = new BitSet();
+      }
+    }
+  }
+
   public ShuffleVertexManager(VertexManagerPluginContext context) {
     super(context);
   }
@@ -313,7 +328,7 @@ public class ShuffleVertexManager extends VertexManagerPlugin {
     updateSourceTaskCount();
     
     LOG.info("OnVertexStarted vertex: " + getContext().getVertexName() +
-             " with " + totalNumSourceTasks + " source tasks and " + 
+             " with " + totalNumBipartiteSourceTasks + " source tasks and " +
              totalTasksToSchedule + " pending tasks");
     
     if (completions != null) {
@@ -330,15 +345,22 @@ public class ShuffleVertexManager extends VertexManagerPlugin {
   @Override
   public void onSourceTaskCompleted(String srcVertexName, Integer srcTaskId) {
     updateSourceTaskCount();
-    Set<Integer> completedSourceTasks = bipartiteSources.get(srcVertexName);
-    if (completedSourceTasks != null) {
-      // duplicate notifications tracking
-      if (completedSourceTasks.add(srcTaskId)) {
-        // source task has completed
-        ++numSourceTasksCompleted;
+    SourceVertexInfo srcInfo = srcVertexInfo.get(srcVertexName);
+    srcInfo.numFinishedTasks++;
+
+    if (srcInfo.edgeProperty.getDataMovementType() == DataMovementType.SCATTER_GATHER) {
+      //handle duplicate events for bipartite sources
+      BitSet completedSourceTasks = srcInfo.finishedTaskSet;
+      if (completedSourceTasks != null) {
+        // duplicate notifications tracking
+        if (!completedSourceTasks.get(srcTaskId)) {
+          completedSourceTasks.set(srcTaskId);
+          // source task has completed
+          ++numBipartiteSourceTasksCompleted;
+        }
       }
-      schedulePendingTasks();
     }
+    schedulePendingTasks();
   }
   
   @Override
@@ -371,14 +393,23 @@ public class ShuffleVertexManager extends VertexManagerPlugin {
     }
     totalTasksToSchedule = pendingTasks.size();
   }
-  
+
+  Iterable<Map.Entry<String, SourceVertexInfo>> getBipartiteInfo() {
+    return Iterables.filter(srcVertexInfo.entrySet(), new Predicate<Map.Entry<String,SourceVertexInfo>>() {
+      public boolean apply(Map.Entry<String, SourceVertexInfo> input) {
+        return (input.getValue().edgeProperty.getDataMovementType() == DataMovementType.SCATTER_GATHER);
+      }
+    });
+  }
+
   void updateSourceTaskCount() {
     // track source vertices
     int numSrcTasks = 0;
-    for(String vertex : bipartiteSources.keySet()) {
-      numSrcTasks += getContext().getVertexNumTasks(vertex);
+    Iterable<Map.Entry<String, SourceVertexInfo>> bipartiteItr = getBipartiteInfo();
+    for(Map.Entry<String, SourceVertexInfo> entry : bipartiteItr) {
+      numSrcTasks += getContext().getVertexNumTasks(entry.getKey());
     }
-    totalNumSourceTasks = numSrcTasks;
+    totalNumBipartiteSourceTasks = numSrcTasks;
   }
 
   /**
@@ -387,7 +418,7 @@ public class ShuffleVertexManager extends VertexManagerPlugin {
    */
   @VisibleForTesting
   boolean determineParallelismAndApply() {
-    if(numSourceTasksCompleted == 0) {
+    if(numBipartiteSourceTasksCompleted == 0) {
       return true;
     }
     
@@ -404,19 +435,19 @@ public class ShuffleVertexManager extends VertexManagerPlugin {
      */
     boolean canDetermineParallelismLater = (completedSourceTasksOutputSize <
         desiredTaskInputDataSize)
-        && (numSourceTasksCompleted < (totalNumSourceTasks * slowStartMaxSrcCompletionFraction));
+        && (numBipartiteSourceTasksCompleted < (totalNumBipartiteSourceTasks * slowStartMaxSrcCompletionFraction));
     if (canDetermineParallelismLater) {
       LOG.info("Defer scheduling tasks; vertex=" + getContext().getVertexName()
-          + ", totalNumSourceTasks=" + totalNumSourceTasks
+          + ", totalNumBipartiteSourceTasks=" + totalNumBipartiteSourceTasks
           + ", completedSourceTasksOutputSize=" + completedSourceTasksOutputSize
           + ", numVertexManagerEventsReceived=" + numVertexManagerEventsReceived
-          + ", numSourceTasksCompleted=" + numSourceTasksCompleted + ", maxThreshold="
-          + (totalNumSourceTasks * slowStartMaxSrcCompletionFraction));
+          + ", numBipartiteSourceTasksCompleted=" + numBipartiteSourceTasksCompleted + ", maxThreshold="
+          + (totalNumBipartiteSourceTasks * slowStartMaxSrcCompletionFraction));
       return false;
     }
 
     long expectedTotalSourceTasksOutputSize =
-        (totalNumSourceTasks * completedSourceTasksOutputSize) / numVertexManagerEventsReceived;
+        (totalNumBipartiteSourceTasks * completedSourceTasksOutputSize) / numVertexManagerEventsReceived;
 
     int desiredTaskParallelism = 
         (int)(
@@ -450,14 +481,16 @@ public class ShuffleVertexManager extends VertexManagerPlugin {
         + " based on actual output: " + completedSourceTasksOutputSize
         + " from " + numVertexManagerEventsReceived + " vertex manager events. "
         + " desiredTaskInputSize: " + desiredTaskInputDataSize + " max slow start tasks:" +
-        (totalNumSourceTasks * slowStartMaxSrcCompletionFraction) + " num sources completed:" +
-        numSourceTasksCompleted);
+        (totalNumBipartiteSourceTasks * slowStartMaxSrcCompletionFraction) + " num sources completed:" +
+        numBipartiteSourceTasksCompleted);
           
     if(finalTaskParallelism < currentParallelism) {
       // final parallelism is less than actual parallelism
       Map<String, EdgeManagerPluginDescriptor> edgeManagers =
-          new HashMap<String, EdgeManagerPluginDescriptor>(bipartiteSources.size());
-      for(String vertex : bipartiteSources.keySet()) {
+          new HashMap<String, EdgeManagerPluginDescriptor>(bipartiteSources);
+      Iterable<Map.Entry<String, SourceVertexInfo>> bipartiteItr = getBipartiteInfo();
+      for(Map.Entry<String, SourceVertexInfo> entry : bipartiteItr) {
+        String vertex = entry.getKey();
         // use currentParallelism for numSourceTasks to maintain original state
         // for the source tasks
         CustomShuffleEdgeManagerConfig edgeManagerConfig =
@@ -500,14 +533,44 @@ public class ShuffleVertexManager extends VertexManagerPlugin {
     }
     getContext().scheduleVertexTasks(scheduledTasks);
   }
-  
+
+  /**
+   * Verify whether every source vertex that has tasks has completed at least 1 task
+   *
+   * @return true if no source vertex with tasks is still waiting for its first completed task
+   */
+  boolean canScheduleTasks() {
+    for(Map.Entry<String, SourceVertexInfo> entry : srcVertexInfo.entrySet()) {
+      String sourceVertex = entry.getKey();
+      int completedTasks = entry.getValue().numFinishedTasks;
+      int numSourceTasks = getContext().getVertexNumTasks(sourceVertex);
+      if (LOG.isDebugEnabled()) {
+        LOG.debug("vertex=" + getContext().getVertexName() + ", srcVertex="+ entry.getKey() +
+                ", finishedTasks=" + entry.getValue().numFinishedTasks);
+      }
+      if (numSourceTasks > 0 && completedTasks <= 0) {
+        return false;
+      }
+    }
+    sourceVerticesScheduled = true;
+    return sourceVerticesScheduled;
+  }
+
   void schedulePendingTasks() {
     int numPendingTasks = pendingTasks.size();
     if (numPendingTasks == 0) {
       return;
     }
-    
-    if (numSourceTasksCompleted == totalNumSourceTasks && numPendingTasks > 0) {
+
+    if (!sourceVerticesScheduled && !canScheduleTasks()) {
+      if (LOG.isDebugEnabled()) {
+        LOG.debug("Defer scheduling tasks for vertex:" + getContext().getVertexName()
+            + " as one task needs to be completed per source vertex");
+      }
+      return;
+    }
+
+    if (numBipartiteSourceTasksCompleted == totalNumBipartiteSourceTasks && numPendingTasks > 0) {
       LOG.info("All source tasks assigned. " +
           "Ramping up " + numPendingTasks + 
           " remaining tasks for vertex: " + getContext().getVertexName());
@@ -516,12 +579,12 @@ public class ShuffleVertexManager extends VertexManagerPlugin {
     }
 
     float completedSourceTaskFraction = 0f;
-    if (totalNumSourceTasks != 0) { // support for 0 source tasks
-      completedSourceTaskFraction = (float)numSourceTasksCompleted/totalNumSourceTasks;
+    if (totalNumBipartiteSourceTasks != 0) { // support for 0 source tasks
+      completedSourceTaskFraction = (float) numBipartiteSourceTasksCompleted / totalNumBipartiteSourceTasks;
     } else {
       completedSourceTaskFraction = 1;
     }
-    
+
     // start scheduling when source tasks completed fraction is more than min.
     // linearly increase the number of scheduled tasks such that all tasks are 
     // scheduled when source tasks completed fraction reaches max
@@ -538,23 +601,19 @@ public class ShuffleVertexManager extends VertexManagerPlugin {
       }
     }
     
-    if (tasksFractionToSchedule > 1) {
-      tasksFractionToSchedule = 1;
-    } else if (tasksFractionToSchedule < 0) {
-      tasksFractionToSchedule = 0;
-    }
-    
+    tasksFractionToSchedule = Math.max(0, Math.min(1, tasksFractionToSchedule));
+
     int numTasksToSchedule = 
         ((int)(tasksFractionToSchedule * totalTasksToSchedule) - 
          (totalTasksToSchedule - numPendingTasks));
     
     if (numTasksToSchedule > 0) {
-      // numTasksToSchedule can be -ve if numSourceTasksCompleted does not 
+      // numTasksToSchedule can be -ve if numBipartiteSourceTasksCompleted does not
       // does not increase monotonically
       LOG.info("Scheduling " + numTasksToSchedule + " tasks for vertex: " + 
                getContext().getVertexName() + " with totalTasks: " +
-               totalTasksToSchedule + ". " + numSourceTasksCompleted + 
-               " source tasks completed out of " + totalNumSourceTasks + 
+               totalTasksToSchedule + ". " + numBipartiteSourceTasksCompleted +
+               " source tasks completed out of " + totalNumBipartiteSourceTasks +
                ". SourceTaskCompletedFraction: " + completedSourceTaskFraction + 
                " min: " + slowStartMinSrcCompletionFraction + 
                " max: " + slowStartMaxSrcCompletionFraction);
@@ -608,12 +667,12 @@ public class ShuffleVertexManager extends VertexManagerPlugin {
     
     Map<String, EdgeProperty> inputs = getContext().getInputVertexEdgeProperties();
     for(Map.Entry<String, EdgeProperty> entry : inputs.entrySet()) {
+      srcVertexInfo.put(entry.getKey(), new SourceVertexInfo(entry.getValue()));
       if (entry.getValue().getDataMovementType() == DataMovementType.SCATTER_GATHER) {
-        String vertex = entry.getKey();
-        bipartiteSources.put(vertex, new HashSet<Integer>());
+        bipartiteSources++;
       }
     }
-    if(bipartiteSources.isEmpty()) {
+    if(bipartiteSources == 0) {
       throw new TezUncheckedException("Atleast 1 bipartite source should exist");
     }
     // dont track the source tasks here since those tasks may themselves be

http://git-wip-us.apache.org/repos/asf/tez/blob/0ad4a4b8/tez-runtime-library/src/test/java/org/apache/tez/dag/library/vertexmanager/TestShuffleVertexManager.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/test/java/org/apache/tez/dag/library/vertexmanager/TestShuffleVertexManager.java b/tez-runtime-library/src/test/java/org/apache/tez/dag/library/vertexmanager/TestShuffleVertexManager.java
index 6d065fc..d967122 100644
--- a/tez-runtime-library/src/test/java/org/apache/tez/dag/library/vertexmanager/TestShuffleVertexManager.java
+++ b/tez-runtime-library/src/test/java/org/apache/tez/dag/library/vertexmanager/TestShuffleVertexManager.java
@@ -18,14 +18,7 @@
 
 package org.apache.tez.dag.library.vertexmanager;
 
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Map.Entry;
-
+import com.google.common.collect.Maps;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.tez.common.ReflectionUtils;
 import org.apache.tez.common.TezUtils;
@@ -33,12 +26,12 @@ import org.apache.tez.dag.api.EdgeManagerPlugin;
 import org.apache.tez.dag.api.EdgeManagerPluginContext;
 import org.apache.tez.dag.api.EdgeManagerPluginDescriptor;
 import org.apache.tez.dag.api.EdgeProperty;
+import org.apache.tez.dag.api.EdgeProperty.SchedulingType;
 import org.apache.tez.dag.api.InputDescriptor;
 import org.apache.tez.dag.api.OutputDescriptor;
 import org.apache.tez.dag.api.TezUncheckedException;
 import org.apache.tez.dag.api.UserPayload;
 import org.apache.tez.dag.api.VertexLocationHint;
-import org.apache.tez.dag.api.EdgeProperty.SchedulingType;
 import org.apache.tez.dag.api.VertexManagerPluginContext;
 import org.apache.tez.dag.api.VertexManagerPluginContext.TaskWithLocationHint;
 import org.apache.tez.dag.api.VertexManagerPluginDescriptor;
@@ -50,9 +43,24 @@ import org.junit.Test;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
 
-import com.google.common.collect.Maps;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
 
-import static org.mockito.Mockito.*;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.anyInt;
+import static org.mockito.Mockito.anyList;
+import static org.mockito.Mockito.anyMap;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 
 public class TestShuffleVertexManager {
 
@@ -140,10 +148,8 @@ public class TestShuffleVertexManager {
 
     // check initialization
     manager = createManager(conf, mockContext, 0.1f, 0.1f);
-    Assert.assertTrue(manager.bipartiteSources.size() == 2);
-    Assert.assertTrue(manager.bipartiteSources.containsKey(mockSrcVertexId1));
-    Assert.assertTrue(manager.bipartiteSources.containsKey(mockSrcVertexId2));
-    
+    Assert.assertTrue(manager.bipartiteSources == 2);
+
     final HashSet<Integer> scheduledTasks = new HashSet<Integer>();
     doAnswer(new Answer() {
       public Object answer(InvocationOnMock invocation) {
@@ -158,7 +164,7 @@ public class TestShuffleVertexManager {
     
     final Map<String, EdgeManagerPlugin> newEdgeManagers =
         new HashMap<String, EdgeManagerPlugin>();
-    
+
     doAnswer(new Answer() {
       public Object answer(InvocationOnMock invocation) {
           when(mockContext.getVertexNumTasks(mockManagedVertexId)).thenReturn(2);
@@ -203,12 +209,15 @@ public class TestShuffleVertexManager {
           return null;
       }}).when(mockContext).setVertexParallelism(eq(2), any(VertexLocationHint.class), anyMap(), anyMap());
     
-    // source vertices have 0 tasks. immediate start of all managed tasks
+    // source vertices have 0 tasks.
     when(mockContext.getVertexNumTasks(mockSrcVertexId1)).thenReturn(0);
     when(mockContext.getVertexNumTasks(mockSrcVertexId2)).thenReturn(0);
     when(mockContext.getVertexNumTasks(mockSrcVertexId3)).thenReturn(1);
 
     manager.onVertexStarted(null);
+    Assert.assertFalse(manager.pendingTasks.isEmpty());
+    Assert.assertTrue(scheduledTasks.size() == 0); // no tasks scheduled
+    manager.onSourceTaskCompleted(mockSrcVertexId3, 0);
     Assert.assertTrue(manager.pendingTasks.isEmpty());
     Assert.assertTrue(scheduledTasks.size() == 4); // all tasks scheduled
     scheduledTasks.clear();
@@ -223,14 +232,21 @@ public class TestShuffleVertexManager {
     manager = createManager(conf, mockContext, 0.1f, 0.1f);
     manager.onVertexStarted(null);
     Assert.assertTrue(manager.pendingTasks.size() == 4); // no tasks scheduled
-    Assert.assertTrue(manager.totalNumSourceTasks == 4);
+    Assert.assertTrue(manager.totalNumBipartiteSourceTasks == 4);
     manager.onVertexManagerEventReceived(vmEvent);
+
+    //1 task of every source vertex needs to be completed.  Until then, we defer scheduling
     manager.onSourceTaskCompleted(mockSrcVertexId1, new Integer(0));
-    // managedVertex tasks reduced
     verify(mockContext, times(0)).setVertexParallelism(anyInt(), any(VertexLocationHint.class), anyMap(), anyMap());
+    Assert.assertEquals(4, manager.pendingTasks.size()); // no task scheduled
+
+    //1 task of every source vertex needs to be completed.  Until then, we defer scheduling
+    manager.onSourceTaskCompleted(mockSrcVertexId2, new Integer(0));
+    verify(mockContext, times(0)).setVertexParallelism(anyInt(), any(VertexLocationHint.class), anyMap(), anyMap());
+    manager.onSourceTaskCompleted(mockSrcVertexId3, new Integer(0));
     Assert.assertEquals(0, manager.pendingTasks.size()); // all tasks scheduled
     Assert.assertEquals(4, scheduledTasks.size());
-    Assert.assertEquals(1, manager.numSourceTasksCompleted);
+    Assert.assertEquals(2, manager.numBipartiteSourceTasksCompleted);
     Assert.assertEquals(5000L, manager.completedSourceTasksOutputSize);
 
     /**
@@ -246,8 +262,8 @@ public class TestShuffleVertexManager {
     manager = createManager(conf, mockContext, 0.01f, 0.75f);
     manager.onVertexStarted(null);
     Assert.assertEquals(4, manager.pendingTasks.size()); // no tasks scheduled
-    Assert.assertEquals(4, manager.totalNumSourceTasks);
-    Assert.assertEquals(0, manager.numSourceTasksCompleted);
+    Assert.assertEquals(4, manager.totalNumBipartiteSourceTasks);
+    Assert.assertEquals(0, manager.numBipartiteSourceTasksCompleted);
 
     //First task in src1 completed with small payload
     manager.onVertexManagerEventReceived(vmEvent); //small payload
@@ -255,7 +271,7 @@ public class TestShuffleVertexManager {
     Assert.assertTrue(manager.determineParallelismAndApply() == false);
     Assert.assertEquals(4, manager.pendingTasks.size());
     Assert.assertEquals(0, scheduledTasks.size()); // no tasks scheduled
-    Assert.assertEquals(1, manager.numSourceTasksCompleted);
+    Assert.assertEquals(1, manager.numBipartiteSourceTasksCompleted);
     Assert.assertEquals(1, manager.numVertexManagerEventsReceived);
     Assert.assertEquals(1L, manager.completedSourceTasksOutputSize);
 
@@ -266,7 +282,7 @@ public class TestShuffleVertexManager {
     Assert.assertTrue(manager.determineParallelismAndApply() == false);
     Assert.assertEquals(4, manager.pendingTasks.size());
     Assert.assertEquals(0, scheduledTasks.size()); // no tasks scheduled
-    Assert.assertEquals(1, manager.numSourceTasksCompleted);
+    Assert.assertEquals(1, manager.numBipartiteSourceTasksCompleted);
     Assert.assertEquals(2, manager.numVertexManagerEventsReceived);
     Assert.assertEquals(2L, manager.completedSourceTasksOutputSize);
 
@@ -281,9 +297,11 @@ public class TestShuffleVertexManager {
         anyMap(),
         anyMap());
     manager.onSourceTaskCompleted(mockSrcVertexId2, new Integer(0));
+    //Need to have 1 task completed from all sources for this vertex
+    manager.onSourceTaskCompleted(mockSrcVertexId3, new Integer(0));
     Assert.assertEquals(1, manager.pendingTasks.size());
     Assert.assertEquals(1, scheduledTasks.size());
-    Assert.assertEquals(2, manager.numSourceTasksCompleted);
+    Assert.assertEquals(2, manager.numBipartiteSourceTasksCompleted);
     Assert.assertEquals(3, manager.numVertexManagerEventsReceived);
     Assert.assertEquals(1202L, manager.completedSourceTasksOutputSize);
 
@@ -301,8 +319,8 @@ public class TestShuffleVertexManager {
     manager = createManager(conf, mockContext, 0.0f, 0.2f);
     manager.onVertexStarted(null);
     Assert.assertEquals(40, manager.pendingTasks.size()); // no tasks scheduled
-    Assert.assertEquals(40, manager.totalNumSourceTasks);
-    Assert.assertEquals(0, manager.numSourceTasksCompleted);
+    Assert.assertEquals(40, manager.totalNumBipartiteSourceTasks);
+    Assert.assertEquals(0, manager.numBipartiteSourceTasksCompleted);
     //send 7 events with payload size as 100
     for(int i=0;i<7;i++) {
       manager.onVertexManagerEventReceived(vmEvent); //small payload
@@ -314,7 +332,8 @@ public class TestShuffleVertexManager {
     }
     //send 8th event with payload size as 100
     manager.onVertexManagerEventReceived(vmEvent);
-    manager.onSourceTaskCompleted(mockSrcVertexId1, new Integer(8));
+    manager.onSourceTaskCompleted(mockSrcVertexId2, new Integer(8));
+    manager.onSourceTaskCompleted(mockSrcVertexId3, new Integer(0));
     //Since max threshold (40 * 0.2 = 8) is met, vertex manager should determine parallelism
     verify(mockContext, times(1)).setVertexParallelism(eq(4), any(VertexLocationHint.class),
         anyMap(),
@@ -334,28 +353,28 @@ public class TestShuffleVertexManager {
     manager = createManager(conf, mockContext, 0.5f, 0.5f);
     manager.onVertexStarted(null);
     Assert.assertEquals(4, manager.pendingTasks.size()); // no tasks scheduled
-    Assert.assertEquals(4, manager.totalNumSourceTasks);
+    Assert.assertEquals(4, manager.totalNumBipartiteSourceTasks);
     // task completion from non-bipartite stage does nothing
     manager.onSourceTaskCompleted(mockSrcVertexId3, new Integer(0));
     Assert.assertEquals(4, manager.pendingTasks.size()); // no tasks scheduled
-    Assert.assertEquals(4, manager.totalNumSourceTasks);
-    Assert.assertEquals(0, manager.numSourceTasksCompleted);
+    Assert.assertEquals(4, manager.totalNumBipartiteSourceTasks);
+    Assert.assertEquals(0, manager.numBipartiteSourceTasksCompleted);
     manager.onVertexManagerEventReceived(vmEvent);
     manager.onSourceTaskCompleted(mockSrcVertexId1, new Integer(0));
     Assert.assertEquals(4, manager.pendingTasks.size());
     Assert.assertEquals(0, scheduledTasks.size()); // no tasks scheduled
-    Assert.assertEquals(1, manager.numSourceTasksCompleted);
+    Assert.assertEquals(1, manager.numBipartiteSourceTasksCompleted);
     Assert.assertEquals(1, manager.numVertexManagerEventsReceived);
     Assert.assertEquals(500L, manager.completedSourceTasksOutputSize);
     // ignore duplicate completion
     manager.onSourceTaskCompleted(mockSrcVertexId1, new Integer(0));
     Assert.assertEquals(4, manager.pendingTasks.size());
     Assert.assertEquals(0, scheduledTasks.size()); // no tasks scheduled
-    Assert.assertEquals(1, manager.numSourceTasksCompleted);
+    Assert.assertEquals(1, manager.numBipartiteSourceTasksCompleted);
     Assert.assertEquals(500L, manager.completedSourceTasksOutputSize);
     
     manager.onVertexManagerEventReceived(vmEvent);
-    manager.onSourceTaskCompleted(mockSrcVertexId1, new Integer(1));
+    manager.onSourceTaskCompleted(mockSrcVertexId2, new Integer(1));
     // managedVertex tasks reduced
     verify(mockContext, times(2)).setVertexParallelism(eq(2), any(VertexLocationHint.class),
         anyMap(),
@@ -366,7 +385,7 @@ public class TestShuffleVertexManager {
     Assert.assertEquals(2, scheduledTasks.size());
     Assert.assertTrue(scheduledTasks.contains(new Integer(0)));
     Assert.assertTrue(scheduledTasks.contains(new Integer(1)));
-    Assert.assertEquals(2, manager.numSourceTasksCompleted);
+    Assert.assertEquals(2, manager.numBipartiteSourceTasksCompleted);
     Assert.assertEquals(2, manager.numVertexManagerEventsReceived);
     Assert.assertEquals(1000L, manager.completedSourceTasksOutputSize);
     
@@ -460,10 +479,8 @@ public class TestShuffleVertexManager {
     
     // check initialization
     manager = createManager(conf, mockContext, 0.1f, 0.1f);
-    Assert.assertTrue(manager.bipartiteSources.size() == 2);
-    Assert.assertTrue(manager.bipartiteSources.containsKey(mockSrcVertexId1));
-    Assert.assertTrue(manager.bipartiteSources.containsKey(mockSrcVertexId2));
-        
+    Assert.assertTrue(manager.bipartiteSources == 2);
+
     final HashSet<Integer> scheduledTasks = new HashSet<Integer>();
     doAnswer(new Answer() {
       public Object answer(InvocationOnMock invocation) {
@@ -507,9 +524,13 @@ public class TestShuffleVertexManager {
     // source vertex have some tasks. min, max == 0
     manager = createManager(conf, mockContext, 0, 0);
     manager.onVertexStarted(null);
-    Assert.assertTrue(manager.totalNumSourceTasks == 4);
+    Assert.assertTrue(manager.totalNumBipartiteSourceTasks == 4);
     Assert.assertTrue(manager.totalTasksToSchedule == 3);
-    Assert.assertTrue(manager.numSourceTasksCompleted == 0);
+    Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 0);
+    //At least 1 task should be complete in every source
+    manager.onSourceTaskCompleted(mockSrcVertexId1, new Integer(0));
+    manager.onSourceTaskCompleted(mockSrcVertexId2, new Integer(0));
+    manager.onSourceTaskCompleted(mockSrcVertexId3, new Integer(0));
     Assert.assertTrue(manager.pendingTasks.isEmpty());
     Assert.assertTrue(scheduledTasks.size() == 3); // all tasks scheduled
     
@@ -517,112 +538,443 @@ public class TestShuffleVertexManager {
     manager = createManager(conf, mockContext, 0.25f, 0.25f);
     manager.onVertexStarted(null);
     Assert.assertTrue(manager.pendingTasks.size() == 3); // no tasks scheduled
-    Assert.assertTrue(manager.totalNumSourceTasks == 4);
+    Assert.assertTrue(manager.totalNumBipartiteSourceTasks == 4);
     // task completion from non-bipartite stage does nothing
     manager.onSourceTaskCompleted(mockSrcVertexId3, new Integer(0));
     Assert.assertTrue(manager.pendingTasks.size() == 3); // no tasks scheduled
-    Assert.assertTrue(manager.totalNumSourceTasks == 4);
-    Assert.assertTrue(manager.numSourceTasksCompleted == 0);
+    Assert.assertTrue(manager.totalNumBipartiteSourceTasks == 4);
+    Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 0);
     manager.onSourceTaskCompleted(mockSrcVertexId1, new Integer(0));
+    //1 task has to be completed in every source vertex. Until then, defer scheduling
+    Assert.assertFalse(manager.pendingTasks.isEmpty());
+    manager.onSourceTaskCompleted(mockSrcVertexId2, new Integer(0));
     Assert.assertTrue(manager.pendingTasks.isEmpty());
     Assert.assertTrue(scheduledTasks.size() == 3); // all tasks scheduled
-    Assert.assertTrue(manager.numSourceTasksCompleted == 1);
+    Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 2);
     
     // min, max > 0 and min == max == absolute max 1.0
     manager = createManager(conf, mockContext, 1.0f, 1.0f);
     manager.onVertexStarted(null);
     Assert.assertTrue(manager.pendingTasks.size() == 3); // no tasks scheduled
-    Assert.assertTrue(manager.totalNumSourceTasks == 4);
+    Assert.assertTrue(manager.totalNumBipartiteSourceTasks == 4);
     // task completion from non-bipartite stage does nothing
     manager.onSourceTaskCompleted(mockSrcVertexId3, new Integer(0));
     Assert.assertTrue(manager.pendingTasks.size() == 3); // no tasks scheduled
-    Assert.assertTrue(manager.totalNumSourceTasks == 4);
-    Assert.assertTrue(manager.numSourceTasksCompleted == 0);
+    Assert.assertTrue(manager.totalNumBipartiteSourceTasks == 4);
+    Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 0);
     manager.onSourceTaskCompleted(mockSrcVertexId1, new Integer(0));
     Assert.assertTrue(manager.pendingTasks.size() == 3);
-    Assert.assertTrue(manager.numSourceTasksCompleted == 1);
+    Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 1);
     manager.onSourceTaskCompleted(mockSrcVertexId1, new Integer(1));
     Assert.assertTrue(manager.pendingTasks.size() == 3);
-    Assert.assertTrue(manager.numSourceTasksCompleted == 2);
+    Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 2);
     manager.onSourceTaskCompleted(mockSrcVertexId2, new Integer(0));
     Assert.assertTrue(manager.pendingTasks.size() == 3);
-    Assert.assertTrue(manager.numSourceTasksCompleted == 3);
+    Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 3);
     manager.onSourceTaskCompleted(mockSrcVertexId2, new Integer(1));
     Assert.assertTrue(manager.pendingTasks.isEmpty());
     Assert.assertTrue(scheduledTasks.size() == 3); // all tasks scheduled
-    Assert.assertTrue(manager.numSourceTasksCompleted == 4);
+    Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 4);
     
     // min, max > 0 and min == max
     manager = createManager(conf, mockContext, 1.0f, 1.0f);
     manager.onVertexStarted(null);
     Assert.assertTrue(manager.pendingTasks.size() == 3); // no tasks scheduled
-    Assert.assertTrue(manager.totalNumSourceTasks == 4);
+    Assert.assertTrue(manager.totalNumBipartiteSourceTasks == 4);
     // task completion from non-bipartite stage does nothing
     manager.onSourceTaskCompleted(mockSrcVertexId3, new Integer(0));
     Assert.assertTrue(manager.pendingTasks.size() == 3); // no tasks scheduled
-    Assert.assertTrue(manager.totalNumSourceTasks == 4);
-    Assert.assertTrue(manager.numSourceTasksCompleted == 0);
+    Assert.assertTrue(manager.totalNumBipartiteSourceTasks == 4);
+    Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 0);
     manager.onSourceTaskCompleted(mockSrcVertexId1, new Integer(0));
     Assert.assertTrue(manager.pendingTasks.size() == 3);
-    Assert.assertTrue(manager.numSourceTasksCompleted == 1);
+    Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 1);
     manager.onSourceTaskCompleted(mockSrcVertexId1, new Integer(1));
     Assert.assertTrue(manager.pendingTasks.size() == 3);
-    Assert.assertTrue(manager.numSourceTasksCompleted == 2);
+    Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 2);
     manager.onSourceTaskCompleted(mockSrcVertexId2, new Integer(0));
     Assert.assertTrue(manager.pendingTasks.size() == 3);
-    Assert.assertTrue(manager.numSourceTasksCompleted == 3);
+    Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 3);
     manager.onSourceTaskCompleted(mockSrcVertexId2, new Integer(1));
     Assert.assertTrue(manager.pendingTasks.isEmpty());
     Assert.assertTrue(scheduledTasks.size() == 3); // all tasks scheduled
-    Assert.assertTrue(manager.numSourceTasksCompleted == 4);
+    Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 4);
     
     // min, max > and min < max
     manager = createManager(conf, mockContext, 0.25f, 0.75f);
     manager.onVertexStarted(null);
     Assert.assertTrue(manager.pendingTasks.size() == 3); // no tasks scheduled
-    Assert.assertTrue(manager.totalNumSourceTasks == 4);
+    Assert.assertTrue(manager.totalNumBipartiteSourceTasks == 4);
     manager.onSourceTaskCompleted(mockSrcVertexId1, new Integer(0));
-    manager.onSourceTaskCompleted(mockSrcVertexId1, new Integer(1));
+    manager.onSourceTaskCompleted(mockSrcVertexId2, new Integer(1));
     Assert.assertTrue(manager.pendingTasks.size() == 2);
     Assert.assertTrue(scheduledTasks.size() == 1); // 1 task scheduled
-    Assert.assertTrue(manager.numSourceTasksCompleted == 2);
+    Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 2);
     // completion of same task again should not get counted
-    manager.onSourceTaskCompleted(mockSrcVertexId1, new Integer(1));
+    manager.onSourceTaskCompleted(mockSrcVertexId2, new Integer(1));
     Assert.assertTrue(manager.pendingTasks.size() == 2);
     Assert.assertTrue(scheduledTasks.size() == 1); // 1 task scheduled
-    Assert.assertTrue(manager.numSourceTasksCompleted == 2);
+    Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 2);
     manager.onSourceTaskCompleted(mockSrcVertexId2, new Integer(0));
     Assert.assertTrue(manager.pendingTasks.size() == 0);
     Assert.assertTrue(scheduledTasks.size() == 2); // 2 tasks scheduled
-    Assert.assertTrue(manager.numSourceTasksCompleted == 3);
+    Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 3);
     scheduledTasks.clear();
-    manager.onSourceTaskCompleted(mockSrcVertexId2, new Integer(1)); // we are done. no action
+    manager.onSourceTaskCompleted(mockSrcVertexId1, new Integer(1)); // we are done. no action
     Assert.assertTrue(manager.pendingTasks.size() == 0);
     Assert.assertTrue(scheduledTasks.size() == 0); // no task scheduled
-    Assert.assertTrue(manager.numSourceTasksCompleted == 4);
+    Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 4);
 
     // min, max > and min < max
     manager = createManager(conf, mockContext, 0.25f, 1.0f);
     manager.onVertexStarted(null);
     Assert.assertTrue(manager.pendingTasks.size() == 3); // no tasks scheduled
-    Assert.assertTrue(manager.totalNumSourceTasks == 4);
+    Assert.assertTrue(manager.totalNumBipartiteSourceTasks == 4);
     manager.onSourceTaskCompleted(mockSrcVertexId1, new Integer(0));
-    manager.onSourceTaskCompleted(mockSrcVertexId1, new Integer(1));
+    manager.onSourceTaskCompleted(mockSrcVertexId2, new Integer(1));
     Assert.assertTrue(manager.pendingTasks.size() == 2);
     Assert.assertTrue(scheduledTasks.size() == 1); // 1 task scheduled
-    Assert.assertTrue(manager.numSourceTasksCompleted == 2);
+    Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 2);
     manager.onSourceTaskCompleted(mockSrcVertexId2, new Integer(0));
     Assert.assertTrue(manager.pendingTasks.size() == 1);
     Assert.assertTrue(scheduledTasks.size() == 1); // 1 task scheduled
-    Assert.assertTrue(manager.numSourceTasksCompleted == 3);
-    manager.onSourceTaskCompleted(mockSrcVertexId2, new Integer(1));
+    Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 3);
+    manager.onSourceTaskCompleted(mockSrcVertexId1, new Integer(1));
     Assert.assertTrue(manager.pendingTasks.size() == 0);
     Assert.assertTrue(scheduledTasks.size() == 1); // no task scheduled
-    Assert.assertTrue(manager.numSourceTasksCompleted == 4);
+    Assert.assertTrue(manager.numBipartiteSourceTasksCompleted == 4);
+
+  }
+
+
+  /**
+   * Tasks should be scheduled only when at least 1 task from each source vertex is complete
+   */
+  @Test(timeout = 5000)
+  public void test_Tez1649_with_scatter_gather_edges() {
+    Configuration conf = new Configuration();
+    conf.setBoolean(
+        ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL,
+        true);
+    conf.setLong(ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE, 1000L);
+    ShuffleVertexManager manager = null;
+
+    HashMap<String, EdgeProperty> mockInputVertices_R2 = new HashMap<String, EdgeProperty>();
+    String r1 = "R1";
+    EdgeProperty eProp1 = EdgeProperty.create(
+        EdgeProperty.DataMovementType.SCATTER_GATHER,
+        EdgeProperty.DataSourceType.PERSISTED,
+        SchedulingType.SEQUENTIAL,
+        OutputDescriptor.create("out"),
+        InputDescriptor.create("in"));
+    String m2 = "M2";
+    EdgeProperty eProp2 = EdgeProperty.create(
+        EdgeProperty.DataMovementType.SCATTER_GATHER,
+        EdgeProperty.DataSourceType.PERSISTED,
+        SchedulingType.SEQUENTIAL,
+        OutputDescriptor.create("out"),
+        InputDescriptor.create("in"));
+    String m3 = "M3";
+    EdgeProperty eProp3 = EdgeProperty.create(
+        EdgeProperty.DataMovementType.SCATTER_GATHER,
+        EdgeProperty.DataSourceType.PERSISTED,
+        SchedulingType.SEQUENTIAL,
+        OutputDescriptor.create("out"),
+        InputDescriptor.create("in"));
+
+    final String mockManagedVertexId_R2 = "R2";
+    mockInputVertices_R2.put(r1, eProp1);
+    mockInputVertices_R2.put(m2, eProp2);
+    mockInputVertices_R2.put(m3, eProp3);
+
+    final VertexManagerPluginContext mockContext_R2 = mock(VertexManagerPluginContext.class);
+    when(mockContext_R2.getInputVertexEdgeProperties()).thenReturn(mockInputVertices_R2);
+    when(mockContext_R2.getVertexName()).thenReturn(mockManagedVertexId_R2);
+    when(mockContext_R2.getVertexNumTasks(mockManagedVertexId_R2)).thenReturn(3);
+    when(mockContext_R2.getVertexNumTasks(r1)).thenReturn(3);
+    when(mockContext_R2.getVertexNumTasks(m2)).thenReturn(3);
+    when(mockContext_R2.getVertexNumTasks(m3)).thenReturn(3);
+
+    final Map<String, EdgeManagerPlugin> edgeManagerR2 =
+        new HashMap<String, EdgeManagerPlugin>();
+    doAnswer(new Answer() {
+      public Object answer(InvocationOnMock invocation) {
+        when(mockContext_R2.getVertexNumTasks(mockManagedVertexId_R2)).thenReturn(2);
+        edgeManagerR2.clear();
+        for (Entry<String, EdgeManagerPluginDescriptor> entry :
+            ((Map<String, EdgeManagerPluginDescriptor>)invocation.getArguments()[2]).entrySet()) {
+
+
+          final UserPayload userPayload = entry.getValue().getUserPayload();
+          EdgeManagerPluginContext emContext = new EdgeManagerPluginContext() {
+            @Override
+            public UserPayload getUserPayload() {
+              return userPayload == null ? null : userPayload;
+            }
+
+            @Override
+            public String getSourceVertexName() {
+              return null;
+            }
+
+            @Override
+            public String getDestinationVertexName() {
+              return null;
+            }
+
+            @Override
+            public int getSourceVertexNumTasks() {
+              return 2;
+            }
+
+            @Override
+            public int getDestinationVertexNumTasks() {
+              return 2;
+            }
+          };
+          EdgeManagerPlugin edgeManager = ReflectionUtils
+              .createClazzInstance(entry.getValue().getClassName(),
+                  new Class[]{EdgeManagerPluginContext.class}, new Object[]{emContext});
+          edgeManager.initialize();
+          edgeManagerR2.put(entry.getKey(), edgeManager);
+        }
+        return null;
+      }}).when(mockContext_R2).setVertexParallelism(eq(2), any(VertexLocationHint.class), anyMap(), anyMap());
+
+    ByteBuffer payload =
+        VertexManagerEventPayloadProto.newBuilder().setOutputSize(50L).build().toByteString().asReadOnlyByteBuffer();
+    VertexManagerEvent vmEvent = VertexManagerEvent.create("Vertex", payload);
+
+    // check initialization
+    manager = createManager(conf, mockContext_R2, 0.001f, 0.001f);
+    Assert.assertTrue(manager.bipartiteSources == 3);
+
+    final HashSet<Integer> scheduledTasks = new HashSet<Integer>();
+    doAnswer(new Answer() {
+      public Object answer(InvocationOnMock invocation) {
+        Object[] args = invocation.getArguments();
+        scheduledTasks.clear();
+        List<TaskWithLocationHint> tasks = (List<TaskWithLocationHint>)args[0];
+        for (TaskWithLocationHint task : tasks) {
+          scheduledTasks.add(task.getTaskIndex());
+        }
+        return null;
+      }}).when(mockContext_R2).scheduleVertexTasks(anyList());
+
+    manager.onVertexStarted(null);
+    manager.onVertexManagerEventReceived(vmEvent);
+    Assert.assertEquals(3, manager.pendingTasks.size()); // no tasks scheduled
+    Assert.assertEquals(9, manager.totalNumBipartiteSourceTasks);
+    Assert.assertEquals(0, manager.numBipartiteSourceTasksCompleted);
 
+    Assert.assertTrue(manager.pendingTasks.size() == 3); // no tasks scheduled
+    Assert.assertTrue(manager.totalNumBipartiteSourceTasks == 9);
+
+    //Send events for all tasks of m3.
+    manager.onSourceTaskCompleted(m3, new Integer(0));
+    manager.onSourceTaskCompleted(m3, new Integer(1));
+    manager.onSourceTaskCompleted(m3, new Integer(2));
+
+    Assert.assertTrue(manager.pendingTasks.size() == 3); // no tasks scheduled
+    Assert.assertTrue(manager.totalNumBipartiteSourceTasks == 9);
+
+    //Send an event for m2. But still we need to wait for at least 1 event from r1.
+    manager.onSourceTaskCompleted(m2, new Integer(0));
+    Assert.assertTrue(manager.pendingTasks.size() == 3); // no tasks scheduled
+    Assert.assertTrue(manager.totalNumBipartiteSourceTasks == 9);
+
+    //Ensure that setVertexParallelism is not called for R2.
+    verify(mockContext_R2, times(0)).setVertexParallelism(anyInt(), any(VertexLocationHint.class),
+        anyMap(),
+        anyMap());
+
+    //1 Task completes in R1.  Now things in R2 should be able to proceed.
+    manager.onSourceTaskCompleted(r1, new Integer(0));
+    verify(mockContext_R2, times(1)).setVertexParallelism(eq(1), any(VertexLocationHint.class),
+        anyMap(),
+        anyMap());
+    Assert.assertTrue(manager.pendingTasks.size() == 0); // all tasks scheduled
+    Assert.assertTrue(scheduledTasks.size() == 3);
+
+    //try with zero task vertices
+    scheduledTasks.clear();
+    when(mockContext_R2.getInputVertexEdgeProperties()).thenReturn(mockInputVertices_R2);
+    when(mockContext_R2.getVertexName()).thenReturn(mockManagedVertexId_R2);
+    when(mockContext_R2.getVertexNumTasks(mockManagedVertexId_R2)).thenReturn(3);
+    when(mockContext_R2.getVertexNumTasks(r1)).thenReturn(0);
+    when(mockContext_R2.getVertexNumTasks(m2)).thenReturn(0);
+    when(mockContext_R2.getVertexNumTasks(m3)).thenReturn(3);
+
+    manager = createManager(conf, mockContext_R2, 0.001f, 0.001f);
+    manager.onVertexStarted(null);
+    Assert.assertEquals(3, manager.pendingTasks.size()); // no tasks scheduled
+    Assert.assertEquals(3, manager.totalNumBipartiteSourceTasks);
+    Assert.assertEquals(0, manager.numBipartiteSourceTasksCompleted);
+
+    Assert.assertTrue(manager.pendingTasks.size() == 3); // no tasks scheduled
+    Assert.assertTrue(manager.totalNumBipartiteSourceTasks == 3);
+
+    //Send an event for one task of m3; r1 and m2 have zero tasks, so that is enough.
+    manager.onSourceTaskCompleted(m3, new Integer(0));
+    Assert.assertTrue(manager.pendingTasks.size() == 0); // all tasks scheduled
+    Assert.assertTrue(scheduledTasks.size() == 3);
+  }
+
+  @Test(timeout = 5000)
+  public void test_Tez1649_with_mixed_edges() {
+    Configuration conf = new Configuration();
+    conf.setBoolean(
+        ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL,
+        true);
+    conf.setLong(ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE, 1000L);
+    ShuffleVertexManager manager = null;
+
+    HashMap<String, EdgeProperty> mockInputVertices =
+        new HashMap<String, EdgeProperty>();
+    String r1 = "R1";
+    EdgeProperty eProp1 = EdgeProperty.create(
+        EdgeProperty.DataMovementType.SCATTER_GATHER,
+        EdgeProperty.DataSourceType.PERSISTED,
+        SchedulingType.SEQUENTIAL,
+        OutputDescriptor.create("out"),
+        InputDescriptor.create("in"));
+    String m2 = "M2";
+    EdgeProperty eProp2 = EdgeProperty.create(
+        EdgeProperty.DataMovementType.BROADCAST,
+        EdgeProperty.DataSourceType.PERSISTED,
+        SchedulingType.SEQUENTIAL,
+        OutputDescriptor.create("out"),
+        InputDescriptor.create("in"));
+    String m3 = "M3";
+    EdgeProperty eProp3 = EdgeProperty.create(
+        EdgeProperty.DataMovementType.BROADCAST,
+        EdgeProperty.DataSourceType.PERSISTED,
+        SchedulingType.SEQUENTIAL,
+        OutputDescriptor.create("out"),
+        InputDescriptor.create("in"));
+
+    final String mockManagedVertexId = "R2";
+
+    mockInputVertices.put(r1, eProp1);
+    mockInputVertices.put(m2, eProp2);
+    mockInputVertices.put(m3, eProp3);
+
+    VertexManagerPluginContext mockContext = mock(VertexManagerPluginContext.class);
+    when(mockContext.getInputVertexEdgeProperties()).thenReturn(mockInputVertices);
+    when(mockContext.getVertexName()).thenReturn(mockManagedVertexId);
+    when(mockContext.getVertexNumTasks(mockManagedVertexId)).thenReturn(3);
+    when(mockContext.getVertexNumTasks(r1)).thenReturn(3);
+    when(mockContext.getVertexNumTasks(m2)).thenReturn(3);
+    when(mockContext.getVertexNumTasks(m3)).thenReturn(3);
+
+    // check initialization
+    manager = createManager(conf, mockContext, 0.001f, 0.001f);
+    Assert.assertTrue(manager.bipartiteSources == 1);
+
+    final HashSet<Integer> scheduledTasks = new HashSet<Integer>();
+    doAnswer(new Answer() {
+      public Object answer(InvocationOnMock invocation) {
+        Object[] args = invocation.getArguments();
+        scheduledTasks.clear();
+        List<TaskWithLocationHint> tasks = (List<TaskWithLocationHint>)args[0];
+        for (TaskWithLocationHint task : tasks) {
+          scheduledTasks.add(task.getTaskIndex());
+        }
+        return null;
+      }}).when(mockContext).scheduleVertexTasks(anyList());
+
+    manager.onVertexStarted(null);
+    Assert.assertEquals(3, manager.pendingTasks.size()); // no tasks scheduled
+    Assert.assertEquals(3, manager.totalNumBipartiteSourceTasks);
+    Assert.assertEquals(0, manager.numBipartiteSourceTasksCompleted);
+
+    //Send events for 2 tasks of r1.
+    manager.onSourceTaskCompleted(r1, new Integer(0));
+    manager.onSourceTaskCompleted(r1, new Integer(1));
+    Assert.assertTrue(manager.pendingTasks.size() == 3); // no tasks scheduled
+    Assert.assertTrue(manager.totalNumBipartiteSourceTasks == 3);
+
+    //Send an event for m2.
+    manager.onSourceTaskCompleted(m2, new Integer(0));
+    Assert.assertTrue(manager.pendingTasks.size() == 3); // no tasks scheduled
+    Assert.assertTrue(manager.totalNumBipartiteSourceTasks == 3);
+
+    //Send an event for m3.
+    manager.onSourceTaskCompleted(m3, new Integer(0));
+    Assert.assertTrue(manager.pendingTasks.size() == 0); // all tasks scheduled
+    Assert.assertTrue(scheduledTasks.size() == 3);
+
+    //Scenario when numBipartiteSourceTasksCompleted == totalNumBipartiteSourceTasks.
+    //Still, wait for a task to be completed from other edges
+    scheduledTasks.clear();
+    manager = createManager(conf, mockContext, 0.001f, 0.001f);
+    manager.onVertexStarted(null);
+    when(mockContext.getInputVertexEdgeProperties()).thenReturn(mockInputVertices);
+    when(mockContext.getVertexName()).thenReturn(mockManagedVertexId);
+    when(mockContext.getVertexNumTasks(mockManagedVertexId)).thenReturn(3);
+    when(mockContext.getVertexNumTasks(r1)).thenReturn(3);
+    when(mockContext.getVertexNumTasks(m2)).thenReturn(3);
+    when(mockContext.getVertexNumTasks(m3)).thenReturn(3);
+    Assert.assertTrue(manager.pendingTasks.size() == 3); // no tasks scheduled
+    Assert.assertTrue(manager.totalNumBipartiteSourceTasks == 3);
+
+    manager.onSourceTaskCompleted(r1, new Integer(0));
+    manager.onSourceTaskCompleted(r1, new Integer(1));
+    manager.onSourceTaskCompleted(r1, new Integer(2));
+    //Tasks from non-scatter edges of m2 and m3 are not complete.
+    Assert.assertTrue(manager.pendingTasks.size() == 3); // no tasks scheduled
+    manager.onSourceTaskCompleted(m2, new Integer(0));
+    manager.onSourceTaskCompleted(m3, new Integer(0));
+    //Got an event from other edges. Schedule all
+    Assert.assertTrue(manager.pendingTasks.size() == 0); // all tasks scheduled
+    Assert.assertTrue(scheduledTasks.size() == 3);
+
+
+    //try with a zero task vertex (with non-scatter-gather edges)
+    scheduledTasks.clear();
+    manager = createManager(conf, mockContext, 0.001f, 0.001f);
+    manager.onVertexStarted(null);
+    when(mockContext.getInputVertexEdgeProperties()).thenReturn(mockInputVertices);
+    when(mockContext.getVertexName()).thenReturn(mockManagedVertexId);
+    when(mockContext.getVertexNumTasks(mockManagedVertexId)).thenReturn(3);
+    when(mockContext.getVertexNumTasks(r1)).thenReturn(3); //scatter gather
+    when(mockContext.getVertexNumTasks(m2)).thenReturn(0); //broadcast
+    when(mockContext.getVertexNumTasks(m3)).thenReturn(3); //broadcast
+
+    manager = createManager(conf, mockContext, 0.001f, 0.001f);
+    manager.onVertexStarted(null);
+    Assert.assertEquals(3, manager.pendingTasks.size()); // no tasks scheduled
+    Assert.assertEquals(3, manager.totalNumBipartiteSourceTasks);
+    Assert.assertEquals(0, manager.numBipartiteSourceTasksCompleted);
+
+    //Send 2 events for tasks of r1.
+    manager.onSourceTaskCompleted(r1, new Integer(0));
+    manager.onSourceTaskCompleted(r1, new Integer(1));
+    Assert.assertTrue(manager.pendingTasks.size() == 3); // no tasks scheduled
+    Assert.assertTrue(scheduledTasks.size() == 0);
+
+    manager.onSourceTaskCompleted(m3, new Integer(1));
+    Assert.assertTrue(manager.pendingTasks.size() == 0); // all tasks scheduled
+    Assert.assertTrue(scheduledTasks.size() == 3);
+
+    //try with all zero task vertices in non-SG edges
+    scheduledTasks.clear();
+    manager = createManager(conf, mockContext, 0.001f, 0.001f);
+    manager.onVertexStarted(null);
+    when(mockContext.getInputVertexEdgeProperties()).thenReturn(mockInputVertices);
+    when(mockContext.getVertexName()).thenReturn(mockManagedVertexId);
+    when(mockContext.getVertexNumTasks(mockManagedVertexId)).thenReturn(3);
+    when(mockContext.getVertexNumTasks(r1)).thenReturn(3); //scatter gather
+    when(mockContext.getVertexNumTasks(m2)).thenReturn(0); //broadcast
+    when(mockContext.getVertexNumTasks(m3)).thenReturn(0); //broadcast
+
+    //Send 1 event for a task of r1.
+    manager.onSourceTaskCompleted(r1, new Integer(0));
+    Assert.assertTrue(manager.pendingTasks.size() == 0); // all tasks scheduled
+    Assert.assertTrue(scheduledTasks.size() == 3);
   }
 
-  private ShuffleVertexManager createManager(Configuration conf, 
+  private ShuffleVertexManager createManager(Configuration conf,
       VertexManagerPluginContext context, float min, float max) {
     conf.setFloat(ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION, min);
     conf.setFloat(ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION, max);


Mime
View raw message