tez-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From ss...@apache.org
Subject [2/2] tez git commit: TEZ-3395. Refactor ShuffleVertexManager to make parts of it re-usable in other plugins. Contributed by Ming Ma.
Date Fri, 19 Aug 2016 19:02:06 GMT
TEZ-3395. Refactor ShuffleVertexManager to make parts of it re-usable in other plugins. Contributed by Ming Ma.


Project: http://git-wip-us.apache.org/repos/asf/tez/repo
Commit: http://git-wip-us.apache.org/repos/asf/tez/commit/9ca2d5be
Tree: http://git-wip-us.apache.org/repos/asf/tez/tree/9ca2d5be
Diff: http://git-wip-us.apache.org/repos/asf/tez/diff/9ca2d5be

Branch: refs/heads/master
Commit: 9ca2d5be601ca5d40808270b668a9018ee78473f
Parents: 1468457
Author: Siddharth Seth <sseth@apache.org>
Authored: Fri Aug 19 12:01:47 2016 -0700
Committer: Siddharth Seth <sseth@apache.org>
Committed: Fri Aug 19 12:01:47 2016 -0700

----------------------------------------------------------------------
 CHANGES.txt                                     |   1 +
 tez-runtime-library/findbugs-exclude.xml        |  10 +
 .../vertexmanager/ShuffleVertexManager.java     | 856 +++++--------------
 .../vertexmanager/ShuffleVertexManagerBase.java | 768 +++++++++++++++++
 .../vertexmanager/TestShuffleVertexManager.java | 114 ++-
 5 files changed, 1040 insertions(+), 709 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/tez/blob/9ca2d5be/CHANGES.txt
----------------------------------------------------------------------
diff --git a/CHANGES.txt b/CHANGES.txt
index a8e7080..10c48f1 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -7,6 +7,7 @@ INCOMPATIBLE CHANGES
 
 ALL CHANGES:
 
+  TEZ-3395. Refactor ShuffleVertexManager to make parts of it re-usable in other plugins.
   TEZ-3413. ConcurrentModificationException in HistoryEventTimelineConversion for AppLaunchedEvent.
   TEZ-3352. MRInputHelpers getStringProperty() should not fail if property value is null.
   TEZ-3409. Log dagId along with other information when submitting a dag.

http://git-wip-us.apache.org/repos/asf/tez/blob/9ca2d5be/tez-runtime-library/findbugs-exclude.xml
----------------------------------------------------------------------
diff --git a/tez-runtime-library/findbugs-exclude.xml b/tez-runtime-library/findbugs-exclude.xml
index b7bb43a..4e15edc 100644
--- a/tez-runtime-library/findbugs-exclude.xml
+++ b/tez-runtime-library/findbugs-exclude.xml
@@ -131,4 +131,14 @@
     </Or>
   </Match>
 
+  <Match>
+    <Class name="org.apache.tez.dag.library.vertexmanager.ShuffleVertexManagerBase"/>
+    <Or>
+      <Field name="numBipartiteSourceTasksCompleted"/>
+      <Field name="totalNumBipartiteSourceTasks"/>
+      <Field name="totalTasksToSchedule"/>
+    </Or>
+    <Bug pattern="IS2_INCONSISTENT_SYNC"/>
+  </Match>
+
 </FindBugsFilter>

http://git-wip-us.apache.org/repos/asf/tez/blob/9ca2d5be/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManager.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManager.java b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManager.java
index c5278dd..0bb2753 100644
--- a/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManager.java
+++ b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManager.java
@@ -18,65 +18,39 @@
 
 package org.apache.tez.dag.library.vertexmanager;
 
-import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
-import com.google.common.base.Predicate;
-import com.google.common.collect.Iterables;
 import com.google.common.collect.Lists;
-import com.google.common.collect.Maps;
-import com.google.common.collect.Sets;
 import com.google.protobuf.ByteString;
 import com.google.protobuf.InvalidProtocolBufferException;
 
-import org.apache.tez.common.TezCommonUtils;
-import org.apache.tez.runtime.library.utils.DATA_RANGE_IN_MB;
-import org.roaringbitmap.RoaringBitmap;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import org.apache.hadoop.classification.InterfaceAudience.Public;
-import org.apache.hadoop.classification.InterfaceStability.Evolving;
-import org.apache.hadoop.conf.Configuration;
 import org.apache.tez.common.TezUtils;
 import org.apache.tez.dag.api.EdgeManagerPluginContext;
 import org.apache.tez.dag.api.EdgeManagerPluginDescriptor;
 import org.apache.tez.dag.api.EdgeManagerPluginOnDemand;
-import org.apache.tez.dag.api.EdgeProperty;
-import org.apache.tez.dag.api.EdgeProperty.DataMovementType;
-import org.apache.tez.dag.api.InputDescriptor;
 import org.apache.tez.dag.api.TezUncheckedException;
 import org.apache.tez.dag.api.UserPayload;
-import org.apache.tez.dag.api.VertexManagerPlugin;
 import org.apache.tez.dag.api.VertexManagerPluginContext;
 import org.apache.tez.dag.api.VertexManagerPluginContext.ScheduleTaskRequest;
 import org.apache.tez.dag.api.VertexManagerPluginDescriptor;
-import org.apache.tez.dag.api.event.VertexState;
-import org.apache.tez.dag.api.event.VertexStateUpdate;
-import org.apache.tez.runtime.api.Event;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.apache.hadoop.classification.InterfaceAudience.Public;
+import org.apache.hadoop.classification.InterfaceStability.Evolving;
+import org.apache.hadoop.conf.Configuration;
 import org.apache.tez.runtime.api.TaskAttemptIdentifier;
-import org.apache.tez.runtime.api.TaskIdentifier;
 import org.apache.tez.runtime.api.events.DataMovementEvent;
 import org.apache.tez.runtime.api.events.InputReadErrorEvent;
-import org.apache.tez.runtime.api.events.VertexManagerEvent;
 import org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads.ShuffleEdgeManagerConfigPayloadProto;
-import org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads.VertexManagerEventPayloadProto;
 
 import javax.annotation.Nullable;
 
-import java.io.ByteArrayInputStream;
-import java.io.DataInputStream;
 import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.Arrays;
-import java.util.BitSet;
 import java.util.Collections;
 import java.util.Comparator;
-import java.util.EnumSet;
-import java.util.HashMap;
-import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
-import java.util.Set;
-import java.util.concurrent.atomic.AtomicBoolean;
 
 /**
  * Starts scheduling tasks when number of completed source tasks crosses 
@@ -85,14 +59,42 @@ import java.util.concurrent.atomic.AtomicBoolean;
  */
 @Public
 @Evolving
-public class ShuffleVertexManager extends VertexManagerPlugin {
-  
+public class ShuffleVertexManager extends ShuffleVertexManagerBase {
+
+  private static final Logger LOG =
+      LoggerFactory.getLogger(ShuffleVertexManager.class);
+
+  /**
+   * The desired size of input per task. Parallelism will be changed to meet this criteria
+   */
+  public static final String TEZ_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE =
+      "tez.shuffle-vertex-manager.desired-task-input-size";
+  public static final long
+      TEZ_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE_DEFAULT = 100 * MB;
+
+  /**
+   * Enables automatic parallelism determination for the vertex. Based on input data
+   * statisitics the parallelism is decreased to a desired level.
+   */
+  public static final String TEZ_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL =
+      "tez.shuffle-vertex-manager.enable.auto-parallel";
+  public static final boolean
+      TEZ_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL_DEFAULT = false;
+
+  /**
+   * Automatic parallelism determination will not decrease parallelism below this value
+   */
+  public static final String TEZ_SHUFFLE_VERTEX_MANAGER_MIN_TASK_PARALLELISM =
+      "tez.shuffle-vertex-manager.min-task-parallelism";
+  public static final int TEZ_SHUFFLE_VERTEX_MANAGER_MIN_TASK_PARALLELISM_DEFAULT = 1;
+
+
   /**
    * In case of a ScatterGather connection, the fraction of source tasks which
    * should complete before tasks for the current vertex are scheduled
    */
-  public static final String TEZ_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION = 
-                                    "tez.shuffle-vertex-manager.min-src-fraction";
+  public static final String TEZ_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION =
+      "tez.shuffle-vertex-manager.min-src-fraction";
   public static final float TEZ_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION_DEFAULT = 0.25f;
 
   /**
@@ -102,109 +104,62 @@ public class ShuffleVertexManager extends VertexManagerPlugin {
    * min-fraction and max-fraction. Defaults to the greater of the default value
    * or tez.shuffle-vertex-manager.min-src-fraction.
    */
-  public static final String TEZ_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION = 
-                                      "tez.shuffle-vertex-manager.max-src-fraction";
+  public static final String TEZ_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION =
+      "tez.shuffle-vertex-manager.max-src-fraction";
   public static final float TEZ_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION_DEFAULT = 0.75f;
-  
-  /**
-   * Enables automatic parallelism determination for the vertex. Based on input data
-   * statisitics the parallelism is decreased to a desired level.
-   */
-  public static final String TEZ_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL = 
-                                      "tez.shuffle-vertex-manager.enable.auto-parallel";
-  public static final boolean
-    TEZ_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL_DEFAULT = false;
-  
-  /**
-   * The desired size of input per task. Parallelism will be changed to meet this criteria
-   */
-  public static final String TEZ_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE = 
-                                     "tez.shuffle-vertex-manager.desired-task-input-size";
-  public static final long
-    TEZ_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE_DEFAULT = 1024*1024*100L;
-
-  /**
-   * Automatic parallelism determination will not decrease parallelism below this value
-   */
-  public static final String TEZ_SHUFFLE_VERTEX_MANAGER_MIN_TASK_PARALLELISM = 
-                                    "tez.shuffle-vertex-manager.min-task-parallelism";
-  public static final int TEZ_SHUFFLE_VERTEX_MANAGER_MIN_TASK_PARALLELISM_DEFAULT = 1;
 
-  
-  private static final Logger LOG = 
-                   LoggerFactory.getLogger(ShuffleVertexManager.class);
-
-  float slowStartMinSrcCompletionFraction;
-  float slowStartMaxSrcCompletionFraction;
-  long desiredTaskInputDataSize = 1024*1024*100L;
-  int minTaskParallelism = 1;
-  boolean enableAutoParallelism = false;
-  boolean parallelismDetermined = false;
-
-  int totalNumBipartiteSourceTasks = 0;
-  int numBipartiteSourceTasksCompleted = 0;
-  int numVertexManagerEventsReceived = 0;
-  List<PendingTaskInfo> pendingTasks = Lists.newLinkedList();
-  List<VertexManagerEvent> pendingVMEvents = Lists.newLinkedList();
-  int totalTasksToSchedule = 0;
-  private AtomicBoolean onVertexStartedDone = new AtomicBoolean(false);
-  
-  private Set<TaskIdentifier> taskWithVmEvents = Sets.newHashSet();
-  
-  //Track source vertex and its finished tasks
-  private final Map<String, SourceVertexInfo> srcVertexInfo = Maps.newConcurrentMap();
-  boolean sourceVerticesScheduled = false;
-  @VisibleForTesting
-  int bipartiteSources = 0;
-  long completedSourceTasksOutputSize = 0;
-  List<VertexStateUpdate> pendingStateUpdates = Lists.newArrayList();
+  ShuffleVertexManagerConfig mgrConfig;
 
   private int[][] targetIndexes;
   private int basePartitionRange;
   private int remainderRangeForLastShuffler;
-  @VisibleForTesting
-  long[] stats; //approximate amount of data to be fetched
-
-  static class SourceVertexInfo {
-    EdgeProperty edgeProperty;
-    boolean vertexIsConfigured;
-    BitSet finishedTaskSet;
-    int numTasks;
-    int numVMEventsReceived;
-    long outputSize;
-
-    SourceVertexInfo(EdgeProperty edgeProperty) {
-      this.edgeProperty = edgeProperty;
-      finishedTaskSet = new BitSet();
-    }
-    
-    int getNumTasks() {
-      return numTasks;
-    }
-    
-    int getNumCompletedTasks() {
-      return finishedTaskSet.cardinality();
-    }
-  }
 
-  static class PendingTaskInfo {
-    private int index;
-    private long outputStats;
 
-    public PendingTaskInfo(int index) {
-      this.index = index;
-    }
+  public ShuffleVertexManager(VertexManagerPluginContext context) {
+    super(context);
+  }
 
-    public String toString() {
-      return "[index=" + index + ", outputStats=" + outputStats + "]";
+  static class ShuffleVertexManagerConfig extends ShuffleVertexManagerBaseConfig {
+    final int minTaskParallelism;
+    public ShuffleVertexManagerConfig(final boolean enableAutoParallelism,
+        final long desiredTaskInputDataSize, final float slowStartMinFraction,
+        final float slowStartMaxFraction, final int minTaskParallelism) {
+      super(enableAutoParallelism, desiredTaskInputDataSize,
+          slowStartMinFraction, slowStartMaxFraction);
+      this.minTaskParallelism = minTaskParallelism;
+      LOG.info("minTaskParallelism {}", this.minTaskParallelism);
+    }
+    int getMinTaskParallelism() {
+      return minTaskParallelism;
     }
   }
 
-  public ShuffleVertexManager(VertexManagerPluginContext context) {
-    super(context);
+  @Override
+  ShuffleVertexManagerBaseConfig initConfiguration() {
+    float slowStartMinFraction = conf.getFloat(
+        TEZ_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION,
+        TEZ_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION_DEFAULT);
+
+    mgrConfig = new ShuffleVertexManagerConfig(
+        conf.getBoolean(
+            TEZ_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL,
+            TEZ_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL_DEFAULT),
+        conf.getLong(
+            TEZ_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE,
+            TEZ_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE_DEFAULT),
+        slowStartMinFraction,
+        conf.getFloat(
+            TEZ_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION,
+            Math.max(slowStartMinFraction,
+            TEZ_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION_DEFAULT)),
+        Math.max(1, conf
+            .getInt(TEZ_SHUFFLE_VERTEX_MANAGER_MIN_TASK_PARALLELISM,
+            TEZ_SHUFFLE_VERTEX_MANAGER_MIN_TASK_PARALLELISM_DEFAULT)));
+    return mgrConfig;
   }
 
-  static int[] createIndices(int partitionRange, int taskIndex, int offSetPerTask) {
+  static int[] createIndices(int partitionRange, int taskIndex,
+      int offSetPerTask) {
     int startIndex = taskIndex * offSetPerTask;
     int[] indices = new int[partitionRange];
     for (int currentIndex = 0; currentIndex < partitionRange; ++currentIndex) {
@@ -285,7 +240,8 @@ public class ShuffleVertexManager extends VertexManagerPlugin {
           sourceTaskIndex * partitionRange 
           + sourceIndex % partitionRange;
 
-      destinationTaskAndInputIndices.put(destinationTaskIndex, Collections.singletonList(targetIndex));
+      destinationTaskAndInputIndices.put(
+          destinationTaskIndex, Collections.singletonList(targetIndex));
     }
 
     @Override
@@ -484,314 +440,88 @@ public class ShuffleVertexManager extends VertexManagerPlugin {
     }
   }
 
-  
-  @Override
-  public synchronized void onVertexStarted(List<TaskAttemptIdentifier> completions) {
-    // examine edges after vertex started because until then these may not have been defined
-    Map<String, EdgeProperty> inputs = getContext().getInputVertexEdgeProperties();
-    for(Map.Entry<String, EdgeProperty> entry : inputs.entrySet()) {
-      srcVertexInfo.put(entry.getKey(), new SourceVertexInfo(entry.getValue()));
-      // TODO what if derived class has already called this
-      // register for status update from all source vertices
-      getContext().registerForVertexStateUpdates(entry.getKey(),
-          EnumSet.of(VertexState.CONFIGURED));
-      if (entry.getValue().getDataMovementType() == DataMovementType.SCATTER_GATHER) {
-        bipartiteSources++;
-      }
-    }
-    if(bipartiteSources == 0) {
-      throw new TezUncheckedException("Atleast 1 bipartite source should exist");
-    }
-
-    for (VertexStateUpdate stateUpdate : pendingStateUpdates) {
-      handleVertexStateUpdate(stateUpdate);
-    }
-    pendingStateUpdates.clear();
-
-    // track the tasks in this vertex
-    updatePendingTasks();
-
-    for (VertexManagerEvent vmEvent : pendingVMEvents) {
-      handleVertexManagerEvent(vmEvent);
-    }
-    pendingVMEvents.clear();
-
-    LOG.info("OnVertexStarted vertex: " + getContext().getVertexName() +
-             " with " + totalNumBipartiteSourceTasks + " source tasks and " +
-             totalTasksToSchedule + " pending tasks");
-    
-    if (completions != null) {
-      for (TaskAttemptIdentifier attempt : completions) {
-        onSourceTaskCompleted(attempt);
-      }
-    }
-    onVertexStartedDone.set(true);
-    // for the special case when source has 0 tasks or min fraction == 0
-    schedulePendingTasks();
-  }
-
-
-  @Override
-  public synchronized void onSourceTaskCompleted(TaskAttemptIdentifier attempt) {
-    String srcVertexName = attempt.getTaskIdentifier().getVertexIdentifier().getName();
-    int srcTaskId = attempt.getTaskIdentifier().getIdentifier();
-    SourceVertexInfo srcInfo = srcVertexInfo.get(srcVertexName);
-    if (srcInfo.vertexIsConfigured) {
-      Preconditions.checkState(srcTaskId < srcInfo.numTasks,  
-          "Received completion for srcTaskId " + srcTaskId + " but Vertex: " + srcVertexName +
-          " has only " + srcInfo.numTasks + " tasks");
-    }
-    //handle duplicate events and count task completions from all source vertices
-    BitSet completedSourceTasks = srcInfo.finishedTaskSet;
-    // duplicate notifications tracking
-    if (!completedSourceTasks.get(srcTaskId)) {
-      completedSourceTasks.set(srcTaskId);
-      // source task has completed
-      if (srcInfo.edgeProperty.getDataMovementType() == DataMovementType.SCATTER_GATHER) {
-        numBipartiteSourceTasksCompleted++;
-      }
-    }
-    schedulePendingTasks();
-  }
-
-  @VisibleForTesting
-  void parseDetailedPartitionStats(List<Integer> partitionStats) {
-    Preconditions.checkState(stats != null, "Stats should be initialized");
-    for (int i = 0; i< partitionStats.size(); i++) {
-      stats[i] += partitionStats.get(i);
-    }
-  }
-
-  @VisibleForTesting
-  void parsePartitionStats(RoaringBitmap partitionStats) {
-    Preconditions.checkState(stats != null, "Stats should be initialized");
-    Iterator<Integer> it = partitionStats.iterator();
-    final DATA_RANGE_IN_MB[] RANGES = DATA_RANGE_IN_MB.values();
-    final int RANGE_LEN = RANGES.length;
-    while (it.hasNext()) {
-      int pos = it.next();
-      int index = ((pos) / RANGE_LEN);
-      int rangeIndex = ((pos) % RANGE_LEN);
-      //Add to aggregated stats and normalize to DATA_RANGE_IN_MB.
-      if (RANGES[rangeIndex].getSizeInMB() > 0) {
-        stats[index] += RANGES[rangeIndex].getSizeInMB();
-      }
-    }
-  }
-
-
-  @Override
-  public synchronized void onVertexManagerEventReceived(VertexManagerEvent vmEvent) {
-    if (onVertexStartedDone.get()) {
-      // internal data structures have been initialized - so handle the events directly
-      handleVertexManagerEvent(vmEvent);
-    } else {
-      // save this event for processing after vertex starts
-      pendingVMEvents.add(vmEvent);
-    }
-  }
-
-  private void handleVertexManagerEvent(VertexManagerEvent vmEvent) {
-    // currently events from multiple attempts of the same task can be ignored because
-    // their output will be the same.
-    TaskIdentifier producerTask = vmEvent.getProducerAttemptIdentifier().getTaskIdentifier();
-    if (!taskWithVmEvents.add(producerTask)) {
-      LOG.info("Ignoring vertex manager event from: " + producerTask);
-      return;
-    }
-
-    String vName = producerTask.getVertexIdentifier().getName();
-    SourceVertexInfo srcInfo = srcVertexInfo.get(vName);
-    Preconditions.checkState(srcInfo != null, "Unknown vmEvent from " + producerTask);
-
-    numVertexManagerEventsReceived++;
-
-    long sourceTaskOutputSize = 0;
-    if (vmEvent.getUserPayload() != null) {
-      // save output size
-      VertexManagerEventPayloadProto proto;
-      try {
-        proto = VertexManagerEventPayloadProto.parseFrom(ByteString.copyFrom(vmEvent.getUserPayload()));
-      } catch (InvalidProtocolBufferException e) {
-        throw new TezUncheckedException(e);
-      }
-      sourceTaskOutputSize = proto.getOutputSize();
-
-      if (proto.hasPartitionStats()) {
-        try {
-          RoaringBitmap partitionStats = new RoaringBitmap();
-          ByteString compressedPartitionStats = proto.getPartitionStats();
-          byte[] rawData = TezCommonUtils.decompressByteStringToByteArray(compressedPartitionStats);
-          ByteArrayInputStream bin = new ByteArrayInputStream(rawData);
-          partitionStats.deserialize(new DataInputStream(bin));
-
-          parsePartitionStats(partitionStats);
-
-        } catch (IOException e) {
-          throw new TezUncheckedException(e);
-        }
-      } else if (proto.hasDetailedPartitionStats()) {
-        List<Integer> detailedPartitionStats = proto.getDetailedPartitionStats().getSizeInMbList();
-        parseDetailedPartitionStats(detailedPartitionStats);
-      }
-
-      srcInfo.numVMEventsReceived++;
-      srcInfo.outputSize += sourceTaskOutputSize;
-      completedSourceTasksOutputSize += sourceTaskOutputSize;
-    }
-    
-    if (LOG.isDebugEnabled()) {
-      LOG.debug("For attempt: " + vmEvent.getProducerAttemptIdentifier()
-          + " received info of output size: " + sourceTaskOutputSize
-          + " vertex numEventsReceived: " + srcInfo.numVMEventsReceived
-          + " vertex output size: " + srcInfo.outputSize
-          + " total numEventsReceived: " + numVertexManagerEventsReceived
-          + " total output size: " + completedSourceTasksOutputSize);
-    }
-  }
-
-
-  void updatePendingTasks() {
-    int tasks = getContext().getVertexNumTasks(getContext().getVertexName());
-    if (tasks == pendingTasks.size() || tasks <= 0) {
-      return;
-    }
-    pendingTasks.clear();
-    for (int i = 0; i < tasks; ++i) {
-      pendingTasks.add(new PendingTaskInfo(i));
-    }
-    totalTasksToSchedule = pendingTasks.size();
-    if (stats == null) {
-      stats = new long[totalTasksToSchedule]; // TODO lost previous data
-    }
-  }
-
-  Iterable<Map.Entry<String, SourceVertexInfo>> getBipartiteInfo() {
-    return Iterables.filter(srcVertexInfo.entrySet(), new Predicate<Map.Entry<String,SourceVertexInfo>>() {
-      public boolean apply(Map.Entry<String, SourceVertexInfo> input) {
-        return (input.getValue().edgeProperty.getDataMovementType() == DataMovementType.SCATTER_GATHER);
-      }
-    });
-  }
-
-  /**
-   * Compute optimal parallelism needed for the job
-   * @return true (if parallelism is determined), false otherwise
-   */
-  @VisibleForTesting
-  boolean determineParallelismAndApply(float minSourceVertexCompletedTaskFraction) {
-    if(numVertexManagerEventsReceived == 0) {
-      if (totalNumBipartiteSourceTasks > 0) {
-        return true;
-      }
-    }
-    
+  ReconfigVertexParams computeRouting() {
     int currentParallelism = pendingTasks.size();
-    /**
-     * When overall completed output size is not even equal to
-     * desiredTaskInputSize, we can wait for some more data to be available to determine
-     * better parallelism until max.fraction is reached.  min.fraction is just a hint to the
-     * framework and need not be honored strictly in this case.
-     */
-    boolean canDetermineParallelismLater = (completedSourceTasksOutputSize <
-        desiredTaskInputDataSize)
-        && (minSourceVertexCompletedTaskFraction < slowStartMaxSrcCompletionFraction);
-    if (canDetermineParallelismLater) {
-      LOG.info("Defer scheduling tasks; vertex=" + getContext().getVertexName()
-          + ", totalNumBipartiteSourceTasks=" + totalNumBipartiteSourceTasks
-          + ", completedSourceTasksOutputSize=" + completedSourceTasksOutputSize
-          + ", numVertexManagerEventsReceived=" + numVertexManagerEventsReceived
-          + ", numBipartiteSourceTasksCompleted=" + numBipartiteSourceTasksCompleted
-          + ", minSourceVertexCompletedTaskFraction=" + minSourceVertexCompletedTaskFraction);
-      return false;
-    }
 
     // Change this to use per partition stats for more accuracy TEZ-2962.
-    // Instead of aggregating overall size and then dividing equally - coalesce partitions until 
+    // Instead of aggregating overall size and then dividing equally - coalesce partitions until
     // desired per partition size is achieved.
     long expectedTotalSourceTasksOutputSize = 0;
     for (Map.Entry<String, SourceVertexInfo> vInfo : getBipartiteInfo()) {
       SourceVertexInfo srcInfo = vInfo.getValue();
       if (srcInfo.numTasks > 0 && srcInfo.numVMEventsReceived > 0) {
         // this assumes that 1 vmEvent is received per completed task - TEZ-2961
-        expectedTotalSourceTasksOutputSize += 
+        expectedTotalSourceTasksOutputSize +=
             (srcInfo.numTasks * srcInfo.outputSize) / srcInfo.numVMEventsReceived;
       }
     }
 
-    LOG.info("Expected output: " + expectedTotalSourceTasksOutputSize + " based on actual output: "
-        + completedSourceTasksOutputSize + " from " + numVertexManagerEventsReceived + " vertex manager events. "
-        + " desiredTaskInputSize: " + desiredTaskInputDataSize + " max slow start tasks:"
-        + (totalNumBipartiteSourceTasks * slowStartMaxSrcCompletionFraction) + " num sources completed:"
-        + numBipartiteSourceTasksCompleted);
-
-    int desiredTaskParallelism = 
-        (int)(
-            (expectedTotalSourceTasksOutputSize+desiredTaskInputDataSize-1)/
-            desiredTaskInputDataSize);
-    if(desiredTaskParallelism < minTaskParallelism) {
-      desiredTaskParallelism = minTaskParallelism;
+    LOG.info("Expected output: {} based on actual output: {} from {} vertex " +
+        "manager events. desiredTaskInputSize: {} max slow start tasks: {} " +
+        " num sources completed: {}", expectedTotalSourceTasksOutputSize,
+        completedSourceTasksOutputSize, numVertexManagerEventsReceived,
+        config.getDesiredTaskInputDataSize(),
+        (totalNumBipartiteSourceTasks * config.getMaxFraction()),
+        numBipartiteSourceTasksCompleted);
+
+    int desiredTaskParallelism =
+        (int)((expectedTotalSourceTasksOutputSize +
+            config.getDesiredTaskInputDataSize() - 1) /
+                config.getDesiredTaskInputDataSize());
+    if(desiredTaskParallelism < mgrConfig.getMinTaskParallelism()) {
+      desiredTaskParallelism = mgrConfig.getMinTaskParallelism();
     }
 
     if(desiredTaskParallelism >= currentParallelism) {
-      LOG.info("Not reducing auto parallelism for vertex: " + getContext().getVertexName()
-          + " since the desired parallelism of " + desiredTaskParallelism
-          + " is greater than or equal to the current parallelism of " + pendingTasks.size());
-      return true;
+      LOG.info("Not reducing auto parallelism for vertex: {}"
+          + " since the desired parallelism of {} is greater than or equal"
+          + " to the current parallelism of {}", getContext().getVertexName(),
+          desiredTaskParallelism, pendingTasks.size());
+      return null;
     }
 
     // most shufflers will be assigned this range
     basePartitionRange = currentParallelism/desiredTaskParallelism;
-    
     if (basePartitionRange <= 1) {
       // nothing to do if range is equal 1 partition. shuffler does it by default
-      LOG.info("Not reducing auto parallelism for vertex: " + getContext().getVertexName()
-          + " by less than half since combining two inputs will potentially break the desired task input size of "
-          + desiredTaskInputDataSize);
-      return true;
+      LOG.info("Not reducing auto parallelism for vertex: {} by less than"
+          + " half since combining two inputs will potentially break the"
+          + " desired task input size of {}", getContext().getVertexName(),
+          config.getDesiredTaskInputDataSize());
+      return null;
     }
-    
     int numShufflersWithBaseRange = currentParallelism / basePartitionRange;
     remainderRangeForLastShuffler = currentParallelism % basePartitionRange;
-    
+
     int finalTaskParallelism = (remainderRangeForLastShuffler > 0) ?
-          (numShufflersWithBaseRange + 1) : (numShufflersWithBaseRange);
-
-    LOG.info("Reducing auto parallelism for vertex: " + getContext().getVertexName()
-        + " from " + pendingTasks.size() + " to " + finalTaskParallelism);
-
-    if(finalTaskParallelism < currentParallelism) {
-      // final parallelism is less than actual parallelism
-      Map<String, EdgeProperty> edgeProperties =
-          new HashMap<String, EdgeProperty>(bipartiteSources);
-      Iterable<Map.Entry<String, SourceVertexInfo>> bipartiteItr = getBipartiteInfo();
-      for(Map.Entry<String, SourceVertexInfo> entry : bipartiteItr) {
-        String vertex = entry.getKey();
-        EdgeProperty oldEdgeProp = entry.getValue().edgeProperty;
-        // use currentParallelism for numSourceTasks to maintain original state
-        // for the source tasks
-        CustomShuffleEdgeManagerConfig edgeManagerConfig =
-            new CustomShuffleEdgeManagerConfig(
-                currentParallelism, finalTaskParallelism, basePartitionRange,
-                ((remainderRangeForLastShuffler > 0) ?
-                    remainderRangeForLastShuffler : basePartitionRange));
-        EdgeManagerPluginDescriptor edgeManagerDescriptor =
-            EdgeManagerPluginDescriptor.create(CustomShuffleEdgeManager.class.getName());
-        edgeManagerDescriptor.setUserPayload(edgeManagerConfig.toUserPayload());
-        EdgeProperty newEdgeProp = EdgeProperty.create(edgeManagerDescriptor,
-            oldEdgeProp.getDataSourceType(), oldEdgeProp.getSchedulingType(), 
-            oldEdgeProp.getEdgeSource(), oldEdgeProp.getEdgeDestination());
-        edgeProperties.put(vertex, newEdgeProp);
-      }
+        (numShufflersWithBaseRange + 1) : (numShufflersWithBaseRange);
+
+    LOG.info("Reducing auto parallelism for vertex: {} from {} to {}",
+        getContext().getVertexName(), pendingTasks.size(),
+        finalTaskParallelism);
+
+    if(finalTaskParallelism >= currentParallelism) {
+      return null;
+    }
+
+    CustomShuffleEdgeManagerConfig edgeManagerConfig =
+        new CustomShuffleEdgeManagerConfig(
+            currentParallelism, finalTaskParallelism, basePartitionRange,
+            ((remainderRangeForLastShuffler > 0) ?
+            remainderRangeForLastShuffler : basePartitionRange));
+    EdgeManagerPluginDescriptor descriptor =
+        EdgeManagerPluginDescriptor.create(CustomShuffleEdgeManager.class.getName());
+    descriptor.setUserPayload(edgeManagerConfig.toUserPayload());
+    ReconfigVertexParams params = new ReconfigVertexParams(finalTaskParallelism, null, descriptor);
+    return params;
+  }
 
-      getContext().reconfigureVertex(finalTaskParallelism, null, edgeProperties);
-      updatePendingTasks();
-      configureTargetMapping(finalTaskParallelism);
-    }
-    return true;
+  @Override
+  void postReconfigVertex() {
+      configureTargetMapping(pendingTasks.size());
   }
 
-  void configureTargetMapping(int tasks) {
+  private void configureTargetMapping(int tasks) {
     targetIndexes = new int[tasks][];
     for (int idx = 0; idx < tasks; ++idx) {
       int partitionRange = basePartitionRange;
@@ -802,45 +532,45 @@ public class ShuffleVertexManager extends VertexManagerPlugin {
       // skip the basePartitionRange per destination task
       targetIndexes[idx] = createIndices(partitionRange, idx, basePartitionRange);
       if (LOG.isDebugEnabled()) {
-        LOG.debug("targetIdx[" + idx + "] to " + Arrays.toString(targetIndexes[idx]));
+        LOG.debug("targetIdx[{}] to {}", idx,
+            Arrays.toString(targetIndexes[idx]));
       }
     }
   }
 
-  void schedulePendingTasks(int numTasksToSchedule, float minSourceVertexCompletedTaskFraction) {
-    // determine parallelism before scheduling the first time
-    // this is the latest we can wait before determining parallelism.
-    // currently this depends on task completion and so this is the best time
-    // to do this. This is the max time we have until we have to launch tasks 
-    // as specified by the user. If/When we move to some other method of 
-    // calculating parallelism or change parallelism while tasks are already
-    // running then we can create other parameters to trigger this calculation.
-    if(enableAutoParallelism && !parallelismDetermined) {
-      parallelismDetermined = determineParallelismAndApply(minSourceVertexCompletedTaskFraction);
-      if (!parallelismDetermined) {
-        //try to determine parallelism later when more info is available.
-        return;
+  /**
+   * Get the list of tasks to schedule based on the overall progress.
+   * Parameter completedSourceAttempt is part of the base class used by other
+   * VertexManagerPlugins; it isn't used here.
+   */
+  @Override
+  List<ScheduleTaskRequest> getTasksToSchedule(
+      TaskAttemptIdentifier completedSourceAttempt) {
+    float minSourceVertexCompletedTaskFraction =
+        getMinSourceVertexCompletedTaskFraction();
+    int numTasksToSchedule = getNumOfTasksToScheduleAndLog(
+        minSourceVertexCompletedTaskFraction);
+    if (numTasksToSchedule > 0) {
+      List<ScheduleTaskRequest> tasksToSchedule =
+          Lists.newArrayListWithCapacity(numTasksToSchedule);
+
+      while (!pendingTasks.isEmpty() && numTasksToSchedule > 0) {
+        numTasksToSchedule--;
+        Integer taskIndex = pendingTasks.get(0).getIndex();
+        tasksToSchedule.add(ScheduleTaskRequest.create(taskIndex, null));
+        pendingTasks.remove(0);
       }
-      getContext().doneReconfiguringVertex();
+      return tasksToSchedule;
     }
+    return null;
+  }
+
+  @Override
+  void processPendingTasks() {
     if (totalNumBipartiteSourceTasks > 0) {
       //Sort in case partition stats are available
       sortPendingTasksBasedOnDataSize();
     }
-    List<ScheduleTaskRequest> scheduledTasks = Lists.newArrayListWithCapacity(numTasksToSchedule);
-
-    while(!pendingTasks.isEmpty() && numTasksToSchedule > 0) {
-      numTasksToSchedule--;
-      Integer taskIndex = pendingTasks.get(0).index;
-      scheduledTasks.add(ScheduleTaskRequest.create(taskIndex, null));
-      pendingTasks.remove(0);
-    }
-
-    getContext().scheduleTasks(scheduledTasks);
-    if (pendingTasks.size() == 0) {
-      // done scheduling all tasks
-      // TODO TEZ-1714 locking issues. getContext().vertexManagerDone();
-    }
   }
 
   private void sortPendingTasksBasedOnDataSize() {
@@ -852,14 +582,14 @@ public class ShuffleVertexManager extends VertexManagerPlugin {
       Collections.sort(pendingTasks, new Comparator<PendingTaskInfo>() {
         @Override
         public int compare(PendingTaskInfo left, PendingTaskInfo right) {
-          return (left.outputStats > right.outputStats) ? -1 :
-              ((left.outputStats == right.outputStats) ? 0 : 1);
+          return (left.getInputStats() > right.getInputStats()) ? -1 :
+              ((left.getInputStats() == right.getInputStats()) ? 0 : 1);
         }
       });
 
       if (LOG.isDebugEnabled()) {
         for (PendingTaskInfo pendingTask : pendingTasks) {
-          LOG.debug("Pending task:" + pendingTask.toString());
+          LOG.debug("Pending task: {}", pendingTask.toString());
         }
       }
     }
@@ -870,10 +600,10 @@ public class ShuffleVertexManager extends VertexManagerPlugin {
    *
    * @return boolean indicating whether stats are computed
    */
-  private synchronized boolean computePartitionSizes() {
+  private boolean computePartitionSizes() {
     boolean computedPartitionSizes = false;
     for (PendingTaskInfo taskInfo : pendingTasks) {
-      int index = taskInfo.index;
+      int index = taskInfo.getIndex();
       if (targetIndexes != null) { //parallelism has changed.
         Preconditions.checkState(index < targetIndexes.length,
             "index=" + index +", targetIndexes length=" + targetIndexes.length);
@@ -882,228 +612,18 @@ public class ShuffleVertexManager extends VertexManagerPlugin {
         for (int i : mapping) {
           totalStats += stats[i];
         }
-        if ((totalStats > 0) && (taskInfo.outputStats != totalStats)) {
-          computedPartitionSizes = true;
-          taskInfo.outputStats = totalStats;
-        }
+        computedPartitionSizes |= taskInfo.setInputStats(totalStats);
       } else {
-        if ((stats[index] > 0) && (stats[index] != taskInfo.outputStats)) {
-          computedPartitionSizes = true;
-          taskInfo.outputStats = stats[index];
-        }
+        computedPartitionSizes |= taskInfo.setInputStats(stats[index]);
       }
     }
     return computedPartitionSizes;
   }
 
-  /**
-   * Verify whether each of the source vertices have completed at least 1 task
-   *
-   * @return boolean
-   */
-  boolean canScheduleTasks() {
-    for(Map.Entry<String, SourceVertexInfo> entry : srcVertexInfo.entrySet()) {
-      // need to check for vertex configured because until that we dont know if numTasks==0 is valid
-      if (!entry.getValue().vertexIsConfigured) { // isConfigured
-        // vertex not scheduled tasks
-        if (LOG.isDebugEnabled()) {
-          LOG.debug("Waiting for vertex: " + entry.getKey() + " in vertex: "
-              + getContext().getVertexName());
-        }
-        return false;
-      }
-    }
-    sourceVerticesScheduled = true;
-    return sourceVerticesScheduled;
-  }
-
-  void schedulePendingTasks() {
-    if (!onVertexStartedDone.get()) {
-      // vertex not started yet
-      return;
-    }
-
-    if (!sourceVerticesScheduled && !canScheduleTasks()) {
-      if (LOG.isDebugEnabled()) {
-        LOG.debug("Defer scheduling tasks for vertex:" + getContext().getVertexName()
-            + " as one task needs to be completed per source vertex");
-      }
-      return;
-    }
-
-    int numPendingTasks = pendingTasks.size();
-    if (numBipartiteSourceTasksCompleted == totalNumBipartiteSourceTasks) {
-      LOG.info("All source tasks assigned. " +
-          "Ramping up " + numPendingTasks + 
-          " remaining tasks for vertex: " + getContext().getVertexName());
-      schedulePendingTasks(numPendingTasks, 1);
-      return;
-    }
-
-    float minSourceVertexCompletedTaskFraction = 1f;
-    String minCompletedVertexName = "";
-    for (Map.Entry<String, SourceVertexInfo> vInfo : getBipartiteInfo()) {
-      SourceVertexInfo srcInfo = vInfo.getValue();
-      // canScheduleTasks check has already verified all sources are configured
-      Preconditions.checkState(srcInfo.vertexIsConfigured, "Vertex: " + vInfo.getKey());
-      if (srcInfo.numTasks > 0) {
-        int numCompletedTasks = srcInfo.getNumCompletedTasks();
-        float completedFraction = (float) numCompletedTasks / srcInfo.numTasks;
-        if (minSourceVertexCompletedTaskFraction > completedFraction) {
-          minSourceVertexCompletedTaskFraction = completedFraction;
-          minCompletedVertexName = vInfo.getKey();
-        }
-      }
-    }
-
-    // start scheduling when source tasks completed fraction is more than min.
-    // linearly increase the number of scheduled tasks such that all tasks are 
-    // scheduled when source tasks completed fraction reaches max
-    float tasksFractionToSchedule = 1;
-    float percentRange = slowStartMaxSrcCompletionFraction - slowStartMinSrcCompletionFraction;
-    if (percentRange > 0) {
-      tasksFractionToSchedule = 
-            (minSourceVertexCompletedTaskFraction - slowStartMinSrcCompletionFraction)/
-            percentRange;
-    } else {
-      // min and max are equal. schedule 100% on reaching min
-      if(minSourceVertexCompletedTaskFraction < slowStartMinSrcCompletionFraction) {
-        tasksFractionToSchedule = 0;
-      }
-    }
-    
-    tasksFractionToSchedule = Math.max(0, Math.min(1, tasksFractionToSchedule));
-
-    // round up to avoid the corner case that single task cannot be scheduled until src completed
-    // fraction reach max
-    int numTasksToSchedule =
-        ((int)(Math.ceil(tasksFractionToSchedule * totalTasksToSchedule)) -
-         (totalTasksToSchedule - numPendingTasks));
-    
-    if (numTasksToSchedule > 0) {
-      // numTasksToSchedule can be -ve if numBipartiteSourceTasksCompleted does not
-      // does not increase monotonically
-      LOG.info("Scheduling " + numTasksToSchedule + " tasks for vertex: " + 
-               getContext().getVertexName() + " with totalTasks: " +
-               totalTasksToSchedule + ". " + numBipartiteSourceTasksCompleted +
-               " source tasks completed out of " + totalNumBipartiteSourceTasks +
-               ". MinSourceTaskCompletedFraction: " + minSourceVertexCompletedTaskFraction +
-               " in Vertex: " + minCompletedVertexName +
-               " min: " + slowStartMinSrcCompletionFraction + 
-               " max: " + slowStartMaxSrcCompletionFraction);
-      schedulePendingTasks(numTasksToSchedule, minSourceVertexCompletedTaskFraction);
-    }
-  }
-
-  @Override
-  public void initialize() {
-    Configuration conf;
-    try {
-      conf = TezUtils.createConfFromUserPayload(getContext().getUserPayload());
-    } catch (IOException e) {
-      throw new TezUncheckedException(e);
-    }
 
-    this.slowStartMinSrcCompletionFraction = conf
-        .getFloat(
-            ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION,
-            ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION_DEFAULT);
-    float defaultSlowStartMaxSrcFraction = ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION_DEFAULT;
-    if (slowStartMinSrcCompletionFraction > defaultSlowStartMaxSrcFraction) {
-      defaultSlowStartMaxSrcFraction = slowStartMinSrcCompletionFraction;
-    }
-    this.slowStartMaxSrcCompletionFraction = conf
-        .getFloat(
-            ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION,
-            defaultSlowStartMaxSrcFraction);
-
-    if (slowStartMinSrcCompletionFraction < 0 || slowStartMaxSrcCompletionFraction > 1
-        || slowStartMaxSrcCompletionFraction < slowStartMinSrcCompletionFraction) {
-      throw new IllegalArgumentException(
-          "Invalid values for slowStartMinSrcCompletionFraction"
-              + "/slowStartMaxSrcCompletionFraction. Min cannot be < 0, max cannot be > 1,"
-              + " and max cannot be < min."
-              + ", configuredMin=" + slowStartMinSrcCompletionFraction
-              + ", configuredMax=" + slowStartMaxSrcCompletionFraction);
-    }
 
-    enableAutoParallelism = conf
-        .getBoolean(
-            ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL,
-            ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL_DEFAULT);
-    desiredTaskInputDataSize = conf
-        .getLong(
-            ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE,
-            ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE_DEFAULT);
-    minTaskParallelism = Math.max(1, conf
-        .getInt(
-            ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_MIN_TASK_PARALLELISM,
-            ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_MIN_TASK_PARALLELISM_DEFAULT));
-    LOG.info("Shuffle Vertex Manager: settings" + " minFrac:"
-        + slowStartMinSrcCompletionFraction + " maxFrac:"
-        + slowStartMaxSrcCompletionFraction + " auto:" + enableAutoParallelism
-        + " desiredTaskIput:" + desiredTaskInputDataSize + " minTasks:"
-        + minTaskParallelism);
-
-    updatePendingTasks();
-    if (enableAutoParallelism) {
-      getContext().vertexReconfigurationPlanned();
-    }
-    // dont track the source tasks here since those tasks may themselves be
-    // dynamically changed as the DAG progresses.
 
-  }
 
-  private void handleVertexStateUpdate(VertexStateUpdate stateUpdate) {
-    Preconditions.checkArgument(stateUpdate.getVertexState() == VertexState.CONFIGURED,
-        "Received incorrect state notification : " + stateUpdate.getVertexState() + " for vertex: "
-            + stateUpdate.getVertexName() + " in vertex: " + getContext().getVertexName());
-    Preconditions.checkArgument(srcVertexInfo.containsKey(stateUpdate.getVertexName()),
-        "Received incorrect vertex notification : " + stateUpdate.getVertexState() + " for vertex: "
-            + stateUpdate.getVertexName() + " in vertex: " + getContext().getVertexName());
-    SourceVertexInfo vInfo = srcVertexInfo.get(stateUpdate.getVertexName()); 
-    Preconditions.checkState(vInfo.vertexIsConfigured == false);
-    vInfo.vertexIsConfigured = true;
-    vInfo.numTasks = getContext().getVertexNumTasks(stateUpdate.getVertexName());
-    if (vInfo.edgeProperty.getDataMovementType() == DataMovementType.SCATTER_GATHER) {
-      totalNumBipartiteSourceTasks += vInfo.numTasks;
-    }
-    LOG.info("Received configured notification : " + stateUpdate.getVertexState() + " for vertex: "
-      + stateUpdate.getVertexName() + " in vertex: " + getContext().getVertexName() + 
-      " numBipartiteSourceTasks: " + totalNumBipartiteSourceTasks);
-    schedulePendingTasks();
-  }
-  
-  @Override
-  public synchronized void onVertexStateUpdated(VertexStateUpdate stateUpdate) {
-    if (stateUpdate.getVertexState() == VertexState.CONFIGURED) {
-      // we will not register for updates until our vertex starts.
-      // derived classes can make other update requests for other states that we should
-      // ignore. However that will not be allowed until the state change notified supports
-      // multiple registers for the same vertex
-      if (onVertexStartedDone.get()) {
-        // normally this if check will always be true because we register after vertex
-        // start.
-        handleVertexStateUpdate(stateUpdate);
-      } else {
-        // normally this code will not trigger since we are the ones who register for
-        // the configured states updates and that will happen after vertex starts.
-        // So this code will only trigger if a derived class also registers for updates
-        // for the same vertices but multiple registers to the same vertex is currently
-        // not supported by the state change notifier code. This is just future proofing
-        // when that is supported
-        // vertex not started yet. So edge info may not have been defined correctly yet.
-        pendingStateUpdates.add(stateUpdate);
-      }
-    }
-  }
-  
-  @Override
-  public synchronized void onRootVertexInitialized(String inputName,
-      InputDescriptor inputDescriptor, List<Event> events) {
-    // Not allowing this for now. Nothing to do.
-  }
-  
   /**
    * Create a {@link VertexManagerPluginDescriptor} builder that can be used to
    * configure the plugin.
@@ -1115,9 +635,10 @@ public class ShuffleVertexManager extends VertexManagerPlugin {
    *          then pass in a {@link Configuration} that is initialized from a
    *          config file. The parameters that are not overridden in code will
    *          be derived from the Configuration object.
-   * @return {@link org.apache.tez.dag.library.vertexmanager.ShuffleVertexManager.ShuffleVertexManagerConfigBuilder}
+   * @return {@link ShuffleVertexManagerConfigBuilder}
    */
-  public static ShuffleVertexManagerConfigBuilder createConfigBuilder(@Nullable Configuration conf) {
+  public static ShuffleVertexManagerConfigBuilder createConfigBuilder(
+      @Nullable Configuration conf) {
     return new ShuffleVertexManagerConfigBuilder(conf);
   }
 
@@ -1135,36 +656,43 @@ public class ShuffleVertexManager extends VertexManagerPlugin {
       }
     }
 
-    public ShuffleVertexManagerConfigBuilder setAutoReduceParallelism(boolean enabled) {
-      conf.setBoolean(ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL, enabled);
+    public ShuffleVertexManagerConfigBuilder setAutoReduceParallelism(
+        boolean enabled) {
+      conf.setBoolean(TEZ_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL,
+          enabled);
       return this;
     }
 
-    public ShuffleVertexManagerConfigBuilder setSlowStartMinSrcCompletionFraction(float minFraction) {
-      conf.setFloat(ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION, minFraction);
+    public ShuffleVertexManagerConfigBuilder
+        setSlowStartMinSrcCompletionFraction(float minFraction) {
+      conf.setFloat(TEZ_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION, minFraction);
       return this;
     }
 
-    public ShuffleVertexManagerConfigBuilder setSlowStartMaxSrcCompletionFraction(float maxFraction) {
-      conf.setFloat(ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION, maxFraction);
+    public ShuffleVertexManagerConfigBuilder
+        setSlowStartMaxSrcCompletionFraction(float maxFraction) {
+      conf.setFloat(TEZ_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION, maxFraction);
       return this;
     }
 
-    public ShuffleVertexManagerConfigBuilder setDesiredTaskInputSize(long desiredTaskInputSize) {
-      conf.setLong(ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE,
+    public ShuffleVertexManagerConfigBuilder setDesiredTaskInputSize(
+        long desiredTaskInputSize) {
+      conf.setLong(TEZ_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE,
           desiredTaskInputSize);
       return this;
     }
 
-    public ShuffleVertexManagerConfigBuilder setMinTaskParallelism(int minTaskParallelism) {
-      conf.setInt(ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_MIN_TASK_PARALLELISM,
-          minTaskParallelism);
+    public ShuffleVertexManagerConfigBuilder setMinTaskParallelism(
+        int minTaskParallelism) {
+      conf.setInt(TEZ_SHUFFLE_VERTEX_MANAGER_MIN_TASK_PARALLELISM,
+        minTaskParallelism);
       return this;
     }
 
     public VertexManagerPluginDescriptor build() {
       VertexManagerPluginDescriptor desc =
-          VertexManagerPluginDescriptor.create(ShuffleVertexManager.class.getName());
+          VertexManagerPluginDescriptor.create(
+              ShuffleVertexManager.class.getName());
 
       try {
         return desc.setUserPayload(TezUtils.createUserPayloadFromConf(this.conf));

http://git-wip-us.apache.org/repos/asf/tez/blob/9ca2d5be/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManagerBase.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManagerBase.java b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManagerBase.java
new file mode 100644
index 0000000..951ce30
--- /dev/null
+++ b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManagerBase.java
@@ -0,0 +1,768 @@
+/**
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.tez.dag.library.vertexmanager;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import com.google.common.base.Predicate;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
+import com.google.protobuf.ByteString;
+import com.google.protobuf.InvalidProtocolBufferException;
+
+import org.apache.tez.common.TezCommonUtils;
+import org.apache.tez.dag.api.EdgeManagerPluginDescriptor;
+import org.apache.tez.dag.api.EdgeProperty;
+import org.apache.tez.dag.api.EdgeProperty.DataMovementType;
+import org.apache.tez.dag.api.InputDescriptor;
+import org.apache.tez.dag.api.TezUncheckedException;
+import org.apache.tez.dag.api.VertexManagerPlugin;
+import org.apache.tez.dag.api.VertexManagerPluginContext;
+import org.apache.tez.dag.api.VertexManagerPluginContext.ScheduleTaskRequest;
+import org.apache.tez.dag.api.event.VertexState;
+import org.apache.tez.dag.api.event.VertexStateUpdate;
+import org.apache.tez.dag.api.VertexLocationHint;
+import org.apache.tez.runtime.library.utils.DATA_RANGE_IN_MB;
+import org.roaringbitmap.RoaringBitmap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.apache.hadoop.classification.InterfaceAudience.Private;
+import org.apache.hadoop.classification.InterfaceStability.Evolving;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.tez.common.TezUtils;
+import org.apache.tez.runtime.api.Event;
+import org.apache.tez.runtime.api.TaskAttemptIdentifier;
+import org.apache.tez.runtime.api.TaskIdentifier;
+import org.apache.tez.runtime.api.events.VertexManagerEvent;
+import org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads.VertexManagerEventPayloadProto;
+
+
+import java.io.ByteArrayInputStream;
+import java.io.DataInputStream;
+import java.io.IOException;
+import java.util.BitSet;
+import java.util.EnumSet;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+/**
+ * Starts scheduling tasks when number of completed source tasks crosses
+ * <code>slowStartMinFraction</code> and schedules all tasks
+ *  when <code>slowStartMaxFraction</code> is reached
+ */
+@Private
+@Evolving
+abstract class ShuffleVertexManagerBase extends VertexManagerPlugin {
+  static long MB = 1024l * 1024l;
+
+  private static final Logger LOG =
+     LoggerFactory.getLogger(ShuffleVertexManagerBase.class);
+
+  ComputeRoutingAction computeRoutingAction = ComputeRoutingAction.WAIT;
+
+  int totalNumBipartiteSourceTasks = 0;
+  int numBipartiteSourceTasksCompleted = 0;
+  int numVertexManagerEventsReceived = 0;
+  List<VertexManagerEvent> pendingVMEvents = Lists.newLinkedList();
+  AtomicBoolean onVertexStartedDone = new AtomicBoolean(false);
+
+  private Set<TaskIdentifier> taskWithVmEvents = Sets.newHashSet();
+
+  //Track source vertex and its finished tasks
+  private final Map<String, SourceVertexInfo> srcVertexInfo = Maps.newConcurrentMap();
+  boolean sourceVerticesScheduled = false;
+  @VisibleForTesting
+  int bipartiteSources = 0;
+  long completedSourceTasksOutputSize = 0;
+  List<VertexStateUpdate> pendingStateUpdates = Lists.newArrayList();
+  List<PendingTaskInfo> pendingTasks = Lists.newLinkedList();
+  int totalTasksToSchedule = 0;
+
+  @VisibleForTesting
+  long[] stats; //approximate amount of data to be fetched
+  Configuration conf;
+  ShuffleVertexManagerBaseConfig config;
+
+  /**
+   * Used when automatic parallelism is enabled
+   * Initially the vertex manager will start in WAIT state.
+   * After it gathers enough data, it will compute the new
+   * parallelism. In some special cases, it will skip the parallelism
+   * computation.
+   */
+  enum ComputeRoutingAction {
+    WAIT, // not enough data yet.
+    SKIP, // skip the routing computation
+    COMPUTE; // compute the new routing
+
+    public boolean determined() {
+      return this != WAIT;
+    }
+  }
+
+  static class SourceVertexInfo {
+    final EdgeProperty edgeProperty;
+    boolean vertexIsConfigured;
+    final BitSet finishedTaskSet;
+    int numTasks;
+    int numVMEventsReceived;
+    long outputSize;
+
+    SourceVertexInfo(final EdgeProperty edgeProperty) {
+      this.edgeProperty = edgeProperty;
+      this.finishedTaskSet = new BitSet();
+    }
+
+    int getNumTasks() {
+      return numTasks;
+    }
+
+    int getNumCompletedTasks() {
+      return finishedTaskSet.cardinality();
+    }
+  }
+
+  static class PendingTaskInfo {
+    final private int index;
+    private long inputStats;
+
+    public PendingTaskInfo(int index) {
+      this.index = index;
+    }
+
+    public String toString() {
+      return "[index=" + index + ", inputStats=" + inputStats + "]";
+    }
+    public int getIndex() {
+      return index;
+    }
+    public long getInputStats() {
+      return inputStats;
+    }
+    // return true if stat is set.
+    public boolean setInputStats(long inputStats) {
+      if (inputStats > 0 && this.inputStats != inputStats) {
+        this.inputStats = inputStats;
+        return true;
+      } else {
+        return false;
+      }
+    }
+  }
+
+  static class ReconfigVertexParams {
+    final private int finalParallelism;
+    final private VertexLocationHint locationHint;
+    final private EdgeManagerPluginDescriptor descriptor;
+
+    public ReconfigVertexParams(final int finalParallelism,
+        final VertexLocationHint locationHint,
+        final EdgeManagerPluginDescriptor descriptor) {
+      this.finalParallelism = finalParallelism;
+      this.locationHint = locationHint;
+      this.descriptor = descriptor;
+    }
+
+    public int getFinalParallelism() {
+      return finalParallelism;
+    }
+    public VertexLocationHint getLocationHint() {
+      return locationHint;
+    }
+    public EdgeManagerPluginDescriptor getDescriptor() {
+      return descriptor;
+    }
+  }
+
+  public ShuffleVertexManagerBase(VertexManagerPluginContext context) {
+    super(context);
+  }
+
+  @Override
+  public synchronized void onVertexStarted(List<TaskAttemptIdentifier> completions) {
+    // examine edges after vertex started because until then these may not have been defined
+    Map<String, EdgeProperty> inputs = getContext().getInputVertexEdgeProperties();
+    for(Map.Entry<String, EdgeProperty> entry : inputs.entrySet()) {
+      srcVertexInfo.put(entry.getKey(), new SourceVertexInfo(entry.getValue()));
+      // TODO what if derived class has already called this
+      // register for status update from all source vertices
+      getContext().registerForVertexStateUpdates(entry.getKey(),
+          EnumSet.of(VertexState.CONFIGURED));
+      if (entry.getValue().getDataMovementType() == DataMovementType.SCATTER_GATHER) {
+        bipartiteSources++;
+      }
+    }
+    if(bipartiteSources == 0) {
+      throw new TezUncheckedException("Atleast 1 bipartite source should exist");
+    }
+
+    for (VertexStateUpdate stateUpdate : pendingStateUpdates) {
+      handleVertexStateUpdate(stateUpdate);
+    }
+    pendingStateUpdates.clear();
+
+    // track the tasks in this vertex
+    updatePendingTasks();
+
+    for (VertexManagerEvent vmEvent : pendingVMEvents) {
+      handleVertexManagerEvent(vmEvent);
+    }
+    pendingVMEvents.clear();
+
+    LOG.info("OnVertexStarted vertex: {} with {} source tasks and {} pending" +
+        " tasks", getContext().getVertexName(), totalNumBipartiteSourceTasks,
+        totalTasksToSchedule);
+
+    if (completions != null) {
+      for (TaskAttemptIdentifier attempt : completions) {
+        onSourceTaskCompleted(attempt);
+      }
+    }
+    onVertexStartedDone.set(true);
+    // for the special case when source has 0 tasks or min fraction == 0
+    processPendingTasks(null);
+  }
+
+
+  @Override
+  public synchronized void onSourceTaskCompleted(TaskAttemptIdentifier attempt) {
+    String srcVertexName = attempt.getTaskIdentifier().getVertexIdentifier().getName();
+    int srcTaskId = attempt.getTaskIdentifier().getIdentifier();
+    SourceVertexInfo srcInfo = srcVertexInfo.get(srcVertexName);
+    if (srcInfo.vertexIsConfigured) {
+      Preconditions.checkState(srcTaskId < srcInfo.numTasks,
+          "Received completion for srcTaskId " + srcTaskId + " but Vertex: " + srcVertexName +
+          " has only " + srcInfo.numTasks + " tasks");
+    }
+    //handle duplicate events and count task completions from all source vertices
+    BitSet completedSourceTasks = srcInfo.finishedTaskSet;
+    // duplicate notifications tracking
+    if (!completedSourceTasks.get(srcTaskId)) {
+      completedSourceTasks.set(srcTaskId);
+      // source task has completed
+      if (srcInfo.edgeProperty.getDataMovementType() == DataMovementType.SCATTER_GATHER) {
+        numBipartiteSourceTasksCompleted++;
+      }
+    }
+    processPendingTasks(attempt);
+  }
+
+  @VisibleForTesting
+  void parsePartitionStats(RoaringBitmap partitionStats) {
+    Preconditions.checkState(stats != null, "Stats should be initialized");
+    Iterator<Integer> it = partitionStats.iterator();
+    final DATA_RANGE_IN_MB[] RANGES = DATA_RANGE_IN_MB.values();
+    final int RANGE_LEN = RANGES.length;
+    while (it.hasNext()) {
+      int pos = it.next();
+      int index = ((pos) / RANGE_LEN);
+      int rangeIndex = ((pos) % RANGE_LEN);
+      //Add to aggregated stats and normalize to DATA_RANGE_IN_MB.
+      if (RANGES[rangeIndex].getSizeInMB() > 0) {
+        stats[index] += RANGES[rangeIndex].getSizeInMB();
+      }
+    }
+  }
+
+  void parseDetailedPartitionStats(List<Integer> partitionStats) {
+    for (int i=0; i<partitionStats.size(); i++) {
+      stats[i] += partitionStats.get(i);
+    }
+  }
+
+  @Override
+  public synchronized void onVertexManagerEventReceived(VertexManagerEvent vmEvent) {
+    if (onVertexStartedDone.get()) {
+      // internal data structures have been initialized - so handle the events directly
+      handleVertexManagerEvent(vmEvent);
+    } else {
+      // save this event for processing after vertex starts
+      pendingVMEvents.add(vmEvent);
+    }
+  }
+
+  private void handleVertexManagerEvent(VertexManagerEvent vmEvent) {
+    // currently events from multiple attempts of the same task can be ignored because
+    // their output will be the same.
+    TaskIdentifier producerTask = vmEvent.getProducerAttemptIdentifier().getTaskIdentifier();
+    if (!taskWithVmEvents.add(producerTask)) {
+      LOG.info("Ignoring vertex manager event from: {}", producerTask);
+      return;
+    }
+
+    String vName = producerTask.getVertexIdentifier().getName();
+    SourceVertexInfo srcInfo = srcVertexInfo.get(vName);
+    Preconditions.checkState(srcInfo != null,
+        "Unknown vmEvent from " + producerTask);
+
+    numVertexManagerEventsReceived++;
+
+    long sourceTaskOutputSize = 0;
+    if (vmEvent.getUserPayload() != null) {
+      // save output size
+      VertexManagerEventPayloadProto proto;
+      try {
+        proto = VertexManagerEventPayloadProto.parseFrom(
+            ByteString.copyFrom(vmEvent.getUserPayload()));
+      } catch (InvalidProtocolBufferException e) {
+        throw new TezUncheckedException(e);
+      }
+      sourceTaskOutputSize = proto.getOutputSize();
+
+      if (proto.hasPartitionStats()) {
+        try {
+          RoaringBitmap partitionStats = new RoaringBitmap();
+          ByteString compressedPartitionStats = proto.getPartitionStats();
+          byte[] rawData = TezCommonUtils.decompressByteStringToByteArray(
+              compressedPartitionStats);
+          ByteArrayInputStream bin = new ByteArrayInputStream(rawData);
+          partitionStats.deserialize(new DataInputStream(bin));
+
+          parsePartitionStats(partitionStats);
+
+        } catch (IOException e) {
+          throw new TezUncheckedException(e);
+        }
+      } else if (proto.hasDetailedPartitionStats()) {
+        List<Integer> detailedPartitionStats =
+            proto.getDetailedPartitionStats().getSizeInMbList();
+        parseDetailedPartitionStats(detailedPartitionStats);
+      }
+      srcInfo.numVMEventsReceived++;
+      srcInfo.outputSize += sourceTaskOutputSize;
+      completedSourceTasksOutputSize += sourceTaskOutputSize;
+    }
+
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("For attempt: {} received info of output size: {}"
+          + " vertex numEventsReceived: {} vertex output size: {}"
+          + " total numEventsReceived: {} total output size: {}",
+          vmEvent.getProducerAttemptIdentifier(), sourceTaskOutputSize,
+          srcInfo.numVMEventsReceived, srcInfo.outputSize,
+          numVertexManagerEventsReceived, completedSourceTasksOutputSize);
+    }
+  }
+
+  void updatePendingTasks() {
+    int tasks = getContext().getVertexNumTasks(getContext().getVertexName());
+    if (tasks == pendingTasks.size() || tasks <= 0) {
+      return;
+    }
+    pendingTasks.clear();
+    for (int i = 0; i < tasks; ++i) {
+      pendingTasks.add(new PendingTaskInfo(i));
+    }
+    totalTasksToSchedule = pendingTasks.size();
+    if (stats == null) {
+      stats = new long[totalTasksToSchedule]; // TODO lost previous data
+    }
+  }
+
+  /**
+   * Beginning of functions related to how new parallelism is determined.
+   * ShuffleVertexManagerBase implements common functionality used by
+   * VertexManagerPlugin, while each VertexManagerPlugin implements its own
+   * routing policy.
+   */
+  private ComputeRoutingAction getComputeRoutingAction(
+      float minSourceVertexCompletedTaskFraction) {
+    if (getNumOfTasksToSchedule(minSourceVertexCompletedTaskFraction) <= 0 &&
+        numBipartiteSourceTasksCompleted != totalNumBipartiteSourceTasks) {
+      // Wait when there aren't enough completed tasks
+      return ComputeRoutingAction.WAIT;
+    } else if (numVertexManagerEventsReceived == 0 &&
+      totalNumBipartiteSourceTasks > 0) {
+      // When source tasks don't have output data,
+      // there will be no VME.
+      return ComputeRoutingAction.SKIP;
+    } else if (
+        completedSourceTasksOutputSize < config.getDesiredTaskInputDataSize()
+        && (minSourceVertexCompletedTaskFraction < config.getMaxFraction())) {
+      /**
+       * When overall completed output size is not even equal to
+       * desiredTaskInputSize, we can wait for some more data to be available to
+       * determine better parallelism until max.fraction is reached.
+       * min.fraction is just a hint to the framework and need not be
+       * honored strictly in this case.
+       */
+      LOG.info("Defer scheduling tasks; vertex = {}"
+          + ", totalNumBipartiteSourceTasks = {}"
+          + ", completedSourceTasksOutputSize = {}"
+          + ", numVertexManagerEventsReceived = {}"
+          + ", numBipartiteSourceTasksCompleted = {}"
+          + ", minSourceVertexCompletedTaskFraction = {}",
+          getContext().getVertexName(), totalNumBipartiteSourceTasks,
+          completedSourceTasksOutputSize, numVertexManagerEventsReceived,
+          numBipartiteSourceTasksCompleted,
+          minSourceVertexCompletedTaskFraction);
+       return ComputeRoutingAction.WAIT;
+    } else {
+      return ComputeRoutingAction.COMPUTE;
+    }
+  }
+
+  /**
+   * Subclass might return null to indicate there is no new routing.
+   */
+  abstract ReconfigVertexParams computeRouting();
+
+  abstract void postReconfigVertex();
+
+  /**
+   * Compute optimal parallelism needed for the job
+   * @return true (if parallelism is determined), false otherwise
+   */
+  @VisibleForTesting
+  boolean determineParallelismAndApply(
+      float minSourceVertexCompletedTaskFraction) {
+    if (computeRoutingAction.equals(ComputeRoutingAction.WAIT)) {
+      ComputeRoutingAction computeRoutingAction = getComputeRoutingAction(
+          minSourceVertexCompletedTaskFraction);
+      if (computeRoutingAction.equals(computeRoutingAction.COMPUTE)) {
+        ReconfigVertexParams params = computeRouting();
+        if (params != null) {
+          reconfigVertex(params.getFinalParallelism(), params.getDescriptor());
+          updatePendingTasks();
+          postReconfigVertex();
+        }
+      }
+      if (!computeRoutingAction.equals(ComputeRoutingAction.WAIT)) {
+        getContext().doneReconfiguringVertex();
+      }
+      this.computeRoutingAction = computeRoutingAction;
+    }
+    return this.computeRoutingAction.determined();
+  }
+
+  private boolean determineParallelismAndApply() {
+    return determineParallelismAndApply(
+        getMinSourceVertexCompletedTaskFraction());
+  }
+  /**
+   * End of functions related to how new parallelism is determined.
+   */
+
+
+  /**
+   * Subclass might return null or empty list to indicate no tasks
+   * to schedule at this point.
+   * @param completedSourceAttempt the completed source task attempt
+   * @return the list of tasks to schedule.
+   */
+  abstract List<ScheduleTaskRequest> getTasksToSchedule(
+      TaskAttemptIdentifier completedSourceAttempt);
+
+  abstract void processPendingTasks();
+
+  private void schedulePendingTasks(
+      TaskAttemptIdentifier completedSourceAttempt) {
+    List<ScheduleTaskRequest> scheduledTasks =
+        getTasksToSchedule(completedSourceAttempt);
+    if (scheduledTasks != null && scheduledTasks.size() > 0) {
+      getContext().scheduleTasks(scheduledTasks);
+    }
+  }
+
+  Iterable<Map.Entry<String, SourceVertexInfo>> getBipartiteInfo() {
+    return Iterables.filter(srcVertexInfo.entrySet(),
+        new Predicate<Map.Entry<String,SourceVertexInfo>>() {
+      public boolean apply(Map.Entry<String, SourceVertexInfo> input) {
+        return (input.getValue().edgeProperty.getDataMovementType() ==
+            DataMovementType.SCATTER_GATHER);
+      }
+    });
+  }
+
+  /**
+   * Verify whether each of the source vertices have completed at least 1 task
+   *
+   * @return boolean
+   */
+  private boolean canScheduleTasks() {
+    for(Map.Entry<String, SourceVertexInfo> entry : srcVertexInfo.entrySet()) {
+      // need to check for vertex configured because until that we dont know
+      // if numTask s== 0 is valid
+      if (!entry.getValue().vertexIsConfigured) { // isConfigured
+        // vertex not scheduled tasks
+        if (LOG.isDebugEnabled()) {
+          LOG.debug("Waiting for vertex: {} in vertex: {}", entry.getKey(),
+              getContext().getVertexName());
+        }
+        return false;
+      }
+    }
+    sourceVerticesScheduled = true;
+    return sourceVerticesScheduled;
+  }
+
+  int getNumOfTasksToScheduleAndLog(float minFraction) {
+    int numTasksToSchedule = getNumOfTasksToSchedule(minFraction);
+    if (numTasksToSchedule > 0) {
+      // numTasksToSchedule can be -ve if minFraction
+      // is less than slowStartMinSrcCompletionFraction.
+      LOG.info("Scheduling {} tasks for vertex: {} with totalTasks: {}. " +
+          "{} source tasks completed out of {}. " +
+          "MinSourceTaskCompletedFraction: {} min: {} max: {}",
+          numTasksToSchedule, getContext().getVertexName(),
+          totalTasksToSchedule, numBipartiteSourceTasksCompleted,
+          totalNumBipartiteSourceTasks, minFraction,
+          config.getMinFraction(), config.getMaxFraction());
+    }
+    return numTasksToSchedule;
+  }
+
+  int getNumOfTasksToSchedule(float minSourceVertexCompletedTaskFraction) {
+    int numPendingTasks = pendingTasks.size();
+    if (numBipartiteSourceTasksCompleted == totalNumBipartiteSourceTasks) {
+      LOG.info("All source tasks completed. Ramping up {} remaining tasks" +
+          " for vertex: {}", numPendingTasks, getContext().getVertexName());
+      return numPendingTasks;
+    }
+
+    // start scheduling when source tasks completed fraction is more than min.
+    // linearly increase the number of scheduled tasks such that all tasks are
+    // scheduled when source tasks completed fraction reaches max
+    float tasksFractionToSchedule = 1;
+    float percentRange =
+        config.getMaxFraction() - config.getMinFraction();
+    if (percentRange > 0) {
+      tasksFractionToSchedule =
+          (minSourceVertexCompletedTaskFraction -
+              config.getMinFraction()) / percentRange;
+    } else {
+      // min and max are equal. schedule 100% on reaching min
+      if(minSourceVertexCompletedTaskFraction <
+          config.getMinFraction()) {
+        tasksFractionToSchedule = 0;
+      }
+    }
+
+    tasksFractionToSchedule =
+        Math.max(0, Math.min(1, tasksFractionToSchedule));
+
+    // round up to avoid the corner case that single task cannot be scheduled
+    // until src completed fraction reach max
+    return ((int)(Math.ceil(tasksFractionToSchedule * totalTasksToSchedule)) -
+        (totalTasksToSchedule - numPendingTasks));
+  }
+
+  float getMinSourceVertexCompletedTaskFraction() {
+    float minSourceVertexCompletedTaskFraction = 1f;
+
+    if (numBipartiteSourceTasksCompleted != totalNumBipartiteSourceTasks) {
+      for (Map.Entry<String, SourceVertexInfo> vInfo : getBipartiteInfo()) {
+        SourceVertexInfo srcInfo = vInfo.getValue();
+        // canScheduleTasks check has already verified all sources are configured
+        Preconditions.checkState(srcInfo.vertexIsConfigured,
+            "Vertex: " + vInfo.getKey());
+        if (srcInfo.numTasks > 0) {
+          int numCompletedTasks = srcInfo.getNumCompletedTasks();
+          float completedFraction =
+              (float) numCompletedTasks / srcInfo.numTasks;
+          if (minSourceVertexCompletedTaskFraction > completedFraction) {
+            minSourceVertexCompletedTaskFraction = completedFraction;
+          }
+        }
+      }
+    }
+    return minSourceVertexCompletedTaskFraction;
+  }
+
+
+  private boolean preconditionsSatisfied() {
+    if (!onVertexStartedDone.get()) {
+      // vertex not started yet
+      return false;
+    }
+
+    if (!sourceVerticesScheduled && !canScheduleTasks()) {
+      if (LOG.isDebugEnabled()) {
+        LOG.debug("Defer scheduling tasks for vertex: {} as one task needs " +
+            "to be completed per source vertex", getContext().getVertexName());
+      }
+      return false;
+    }
+    return true;
+  }
+
+  /**
+   * Process pending tasks when a source task has completed.
+   * The processing goes through 4 steps.
+   * Step 1: Precondition check such as whether the vertex has started.
+   * Step 2: Determine new parallelism if possible.
+   * Step 3: Process pending tasks such as sorting based on size.
+   * Step 4: Schedule the pending tasks.
+   * @param completedSourceAttempt
+   */
+  private void processPendingTasks(TaskAttemptIdentifier completedSourceAttempt) {
+    if (!preconditionsSatisfied()) {
+      return;
+    }
+
+    // determine parallelism before scheduling the first time
+    // this is the latest we can wait before determining parallelism.
+    // currently this depends on task completion and so this is the best time
+    // to do this. This is the max time we have until we have to launch tasks
+    // as specified by the user. If/When we move to some other method of
+    // calculating parallelism or change parallelism while tasks are already
+    // running then we can create other parameters to trigger this calculation.
+    if(config.isAutoParallelismEnabled()) {
+      if (!determineParallelismAndApply()) {
+        //try to determine parallelism later when more info is available.
+        return;
+      }
+    }
+    processPendingTasks();
+    schedulePendingTasks(completedSourceAttempt);
+  }
+
+  static class ShuffleVertexManagerBaseConfig {
+    final private boolean enableAutoParallelism;
+    final private long desiredTaskInputDataSize;
+    final private float slowStartMinFraction;
+    final private float slowStartMaxFraction;
+    public ShuffleVertexManagerBaseConfig(final boolean enableAutoParallelism,
+        final long desiredTaskInputDataSize, final float slowStartMinFraction,
+        final float slowStartMaxFraction) {
+      if (slowStartMinFraction < 0 || slowStartMaxFraction > 1
+          || slowStartMaxFraction < slowStartMinFraction) {
+        throw new IllegalArgumentException(
+            "Invalid values for slowStartMinFraction"
+                + "/slowStartMaxFraction. Min "
+                + "cannot be < 0, max cannot be > 1, and max cannot be < min."
+                + ", configuredMin=" + slowStartMinFraction
+                + ", configuredMax=" + slowStartMaxFraction);
+      }
+
+      this.enableAutoParallelism = enableAutoParallelism;
+      this.desiredTaskInputDataSize = desiredTaskInputDataSize;
+      this.slowStartMinFraction = slowStartMinFraction;
+      this.slowStartMaxFraction = slowStartMaxFraction;
+
+      LOG.info("Settings minFrac: {} maxFrac: {} auto: {} desiredTaskIput: {}",
+          slowStartMinFraction, slowStartMaxFraction, enableAutoParallelism,
+          desiredTaskInputDataSize);
+    }
+
+    public boolean isAutoParallelismEnabled() {
+      return this.enableAutoParallelism;
+    }
+    public long getDesiredTaskInputDataSize() {
+      return this.desiredTaskInputDataSize;
+    }
+    public float getMinFraction() {
+      return this.slowStartMinFraction;
+    }
+    public float getMaxFraction() {
+      return this.slowStartMaxFraction;
+    }
+  }
+
+  abstract ShuffleVertexManagerBaseConfig initConfiguration();
+
+  @Override
+  public void initialize() {
+    try {
+      conf = TezUtils.createConfFromUserPayload(getContext().getUserPayload());
+    } catch (IOException e) {
+      throw new TezUncheckedException(e);
+    }
+    config = initConfiguration();
+    updatePendingTasks();
+    if (config.isAutoParallelismEnabled()) {
+      getContext().vertexReconfigurationPlanned();
+    }
+    // dont track the source tasks here since those tasks may themselves be
+    // dynamically changed as the DAG progresses.
+  }
+
+  private void handleVertexStateUpdate(VertexStateUpdate stateUpdate) {
+    Preconditions.checkArgument(stateUpdate.getVertexState() == VertexState.CONFIGURED,
+        "Received incorrect state notification : " + stateUpdate.getVertexState() + " for vertex: "
+            + stateUpdate.getVertexName() + " in vertex: " + getContext().getVertexName());
+    Preconditions.checkArgument(srcVertexInfo.containsKey(stateUpdate.getVertexName()),
+        "Received incorrect vertex notification : " + stateUpdate.getVertexState() + " for vertex: "
+            + stateUpdate.getVertexName() + " in vertex: " + getContext().getVertexName());
+    SourceVertexInfo vInfo = srcVertexInfo.get(stateUpdate.getVertexName());
+    Preconditions.checkState(vInfo.vertexIsConfigured == false);
+    vInfo.vertexIsConfigured = true;
+    vInfo.numTasks = getContext().getVertexNumTasks(stateUpdate.getVertexName());
+    if (vInfo.edgeProperty.getDataMovementType() == DataMovementType.SCATTER_GATHER) {
+      totalNumBipartiteSourceTasks += vInfo.numTasks;
+    }
+    LOG.info("Received configured notification : {}" + " for vertex: {} in" +
+        " vertex: {}" + " numBipartiteSourceTasks: {}",
+        stateUpdate.getVertexState(), stateUpdate.getVertexName(),
+        getContext().getVertexName(), totalNumBipartiteSourceTasks);
+    processPendingTasks(null);
+  }
+
+  @Override
+  public synchronized void onVertexStateUpdated(VertexStateUpdate stateUpdate) {
+    if (stateUpdate.getVertexState() == VertexState.CONFIGURED) {
+      // we will not register for updates until our vertex starts.
+      // derived classes can make other update requests for other states that we should
+      // ignore. However that will not be allowed until the state change notified supports
+      // multiple registers for the same vertex
+      if (onVertexStartedDone.get()) {
+        // normally this if check will always be true because we register after vertex
+        // start.
+        handleVertexStateUpdate(stateUpdate);
+      } else {
+        // normally this code will not trigger since we are the ones who register for
+        // the configured states updates and that will happen after vertex starts.
+        // So this code will only trigger if a derived class also registers for updates
+        // for the same vertices but multiple registers to the same vertex is currently
+        // not supported by the state change notifier code. This is just future proofing
+        // when that is supported
+        // vertex not started yet. So edge info may not have been defined correctly yet.
+        pendingStateUpdates.add(stateUpdate);
+      }
+    }
+  }
+
+  @Override
+  public synchronized void onRootVertexInitialized(String inputName,
+      InputDescriptor inputDescriptor, List<Event> events) {
+    // Not allowing this for now. Nothing to do.
+  }
+
+  private void reconfigVertex(final int finalTaskParallelism,
+      final EdgeManagerPluginDescriptor edgeManagerDescriptor) {
+    Map<String, EdgeProperty> edgeProperties =
+        new HashMap<String, EdgeProperty>(bipartiteSources);
+    Iterable<Map.Entry<String, SourceVertexInfo>> bipartiteItr = getBipartiteInfo();
+    for(Map.Entry<String, SourceVertexInfo> entry : bipartiteItr) {
+      String vertex = entry.getKey();
+      EdgeProperty oldEdgeProp = entry.getValue().edgeProperty;
+      EdgeProperty newEdgeProp = EdgeProperty.create(edgeManagerDescriptor,
+          oldEdgeProp.getDataSourceType(), oldEdgeProp.getSchedulingType(),
+          oldEdgeProp.getEdgeSource(), oldEdgeProp.getEdgeDestination());
+      edgeProperties.put(vertex, newEdgeProp);
+    }
+
+    getContext().reconfigureVertex(finalTaskParallelism, null, edgeProperties);
+  }
+}


Mime
View raw message