tez-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jeag...@apache.org
Subject tez git commit: TEZ-3274. Vertex with MRInput and broadcast input does not respect slow start (Eric Badger via jeagles)
Date Fri, 30 Jun 2017 20:10:15 GMT
Repository: tez
Updated Branches:
  refs/heads/master 59f57c10d -> b32b66f54


TEZ-3274. Vertex with MRInput and broadcast input does not respect slow start (Eric Badger via jeagles)


Project: http://git-wip-us.apache.org/repos/asf/tez/repo
Commit: http://git-wip-us.apache.org/repos/asf/tez/commit/b32b66f5
Tree: http://git-wip-us.apache.org/repos/asf/tez/tree/b32b66f5
Diff: http://git-wip-us.apache.org/repos/asf/tez/diff/b32b66f5

Branch: refs/heads/master
Commit: b32b66f5481ba785e8ecfbd62151ae8ad0739c68
Parents: 59f57c1
Author: Jonathan Eagles <jeagles@yahoo-inc.com>
Authored: Fri Jun 30 15:09:10 2017 -0500
Committer: Jonathan Eagles <jeagles@yahoo-inc.com>
Committed: Fri Jun 30 15:09:10 2017 -0500

----------------------------------------------------------------------
 .../app/dag/impl/RootInputVertexManager.java    | 450 +++++++++++++++++-
 .../apache/tez/dag/app/dag/impl/VertexImpl.java |   4 +-
 .../dag/impl/TestRootInputVertexManager.java    | 459 ++++++++++++++++++-
 .../tez/dag/app/dag/impl/TestVertexImpl.java    |   3 +-
 .../tez/dag/app/dag/impl/TestVertexManager.java |   9 +-
 .../tez/test/TestExceptionPropagation.java      |  40 +-
 6 files changed, 947 insertions(+), 18 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/tez/blob/b32b66f5/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/RootInputVertexManager.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/RootInputVertexManager.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/RootInputVertexManager.java
index c1e96f3..3205983 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/RootInputVertexManager.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/RootInputVertexManager.java
@@ -18,35 +18,259 @@
 
 package org.apache.tez.dag.app.dag.impl;
 
+import java.io.IOException;
+import java.util.BitSet;
+import java.util.EnumSet;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicBoolean;
 
+import com.google.common.annotations.VisibleForTesting;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.tez.common.TezUtils;
+import org.apache.tez.dag.api.EdgeProperty;
 import org.apache.tez.dag.api.InputDescriptor;
+import org.apache.tez.dag.api.TezUncheckedException;
 import org.apache.tez.dag.api.UserPayload;
+import org.apache.tez.dag.api.VertexManagerPlugin;
 import org.apache.tez.dag.api.VertexManagerPluginContext;
+import org.apache.tez.dag.api.VertexManagerPluginDescriptor;
+import org.apache.tez.dag.api.event.VertexState;
+import org.apache.tez.dag.api.event.VertexStateUpdate;
 import org.apache.tez.runtime.api.Event;
 import org.apache.tez.runtime.api.InputSpecUpdate;
+import org.apache.tez.runtime.api.TaskAttemptIdentifier;
 import org.apache.tez.runtime.api.events.InputConfigureVertexTasksEvent;
 import org.apache.tez.runtime.api.events.InputDataInformationEvent;
 import org.apache.tez.runtime.api.events.InputUpdatePayloadEvent;
+import org.apache.tez.runtime.api.events.VertexManagerEvent;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import com.google.common.base.Preconditions;
 import com.google.common.collect.Lists;
 
-public class RootInputVertexManager extends ImmediateStartVertexManager {
+import javax.annotation.Nullable;
+
+public class RootInputVertexManager extends VertexManagerPlugin {
 
   private static final Logger LOG = 
       LoggerFactory.getLogger(RootInputVertexManager.class);
+
+  /**
+   * Enables slow start for the vertex. Based on min/max fraction configs
+   */
+  public static final String TEZ_ROOT_INPUT_VERTEX_MANAGER_ENABLE_SLOW_START =
+      "tez.root-input-vertex-manager.enable.slow-start";
+  public static final boolean
+      TEZ_ROOT_INPUT_VERTEX_MANAGER_ENABLE_SLOW_START_DEFAULT = false;
+
+  /**
+   * In case of a Broadcast connection, the fraction of source tasks which
+   * should complete before tasks for the current vertex are scheduled
+   */
+  public static final String TEZ_ROOT_INPUT_VERTEX_MANAGER_MIN_SRC_FRACTION =
+      "tez.root-input-vertex-manager.min-src-fraction";
+  public static final float
+      TEZ_ROOT_INPUT_VERTEX_MANAGER_MIN_SRC_FRACTION_DEFAULT = 0.25f;
+
+  /**
+   * In case of a Broadcast connection, once this fraction of source tasks
+   * have completed, all tasks on the current vertex can be scheduled. Number of
+   * tasks ready for scheduling on the current vertex scales linearly between
+   * min-fraction and max-fraction. Defaults to the greater of the default value
+   * or tez.root-input-vertex-manager.min-src-fraction.
+   */
+  public static final String TEZ_ROOT_INPUT_VERTEX_MANAGER_MAX_SRC_FRACTION =
+      "tez.root-input-vertex-manager.max-src-fraction";
+  public static final float
+      TEZ_ROOT_INPUT_VERTEX_MANAGER_MAX_SRC_FRACTION_DEFAULT = 0.75f;
   
   private String configuredInputName;
 
+  int totalNumSourceTasks = 0;
+  int numSourceTasksCompleted = 0;
+  private AtomicBoolean onVertexStartedDone = new AtomicBoolean(false);
+
+  private final Map<String, SourceVertexInfo> srcVertexInfo = new
+      ConcurrentHashMap<>();
+  boolean sourceVerticesScheduled = false;
+  List<PendingTaskInfo> pendingTasks = Lists.newLinkedList();
+  int totalTasksToSchedule = 0;
+
+  boolean slowStartEnabled = false;
+  float slowStartMinFraction = 0;
+  float slowStartMaxFraction = 0;
+
+  @VisibleForTesting
+  Configuration conf;
+
+  static class PendingTaskInfo {
+    final private int index;
+
+    public PendingTaskInfo(int index) {
+      this.index = index;
+    }
+
+    public String toString() {
+      return "[index=" + index + "]";
+    }
+    public int getIndex() {
+      return index;
+    }
+  }
+
+  static class SourceVertexInfo {
+    final EdgeProperty edgeProperty;
+    boolean vertexIsConfigured;
+    final BitSet finishedTaskSet;
+    int numTasks;
+
+    SourceVertexInfo(final EdgeProperty edgeProperty,
+       int totalTasksToSchedule) {
+      this.edgeProperty = edgeProperty;
+      this.finishedTaskSet = new BitSet();
+    }
+
+    int getNumTasks() {
+      return numTasks;
+    }
+
+    int getNumCompletedTasks() {
+      return finishedTaskSet.cardinality();
+    }
+  }
+
+  SourceVertexInfo createSourceVertexInfo(EdgeProperty edgeProperty,
+      int numTasks) {
+    return new SourceVertexInfo(edgeProperty, numTasks);
+  }
+
   public RootInputVertexManager(VertexManagerPluginContext context) {
     super(context);
   }
 
+  @Override
+  public void onVertexStarted(List<TaskAttemptIdentifier> completions) {
+    Map<String, EdgeProperty> edges = getContext().
+        getInputVertexEdgeProperties();
+    for (Map.Entry<String, EdgeProperty> entry : edges.entrySet()) {
+      String srcVertex = entry.getKey();
+      //track vertices with task count > 0
+      int numTasks = getContext().getVertexNumTasks(srcVertex);
+      if (numTasks > 0) {
+        LOG.info("Task count in " + srcVertex + ": " + numTasks);
+        srcVertexInfo.put(srcVertex, createSourceVertexInfo(entry.getValue(),
+            getContext().getVertexNumTasks(getContext().getVertexName())));
+        getContext().registerForVertexStateUpdates(srcVertex,
+            EnumSet.of(VertexState.CONFIGURED));
+      } else {
+        LOG.info("Vertex: " + getContext().getVertexName() + "; Ignoring "
+            + srcVertex + " as it has " + numTasks + " tasks");
+      }
+    }
+    onVertexStartedDone.set(true);
+    // track the tasks in this vertex
+    updatePendingTasks();
+    processPendingTasks();
+  }
+
+  @Override
+  public void onVertexStateUpdated(VertexStateUpdate stateUpdate) {
+    Preconditions.checkArgument(stateUpdate.getVertexState() ==
+        VertexState.CONFIGURED,
+        "Received incorrect state notification : "
+        + stateUpdate.getVertexState() + " for vertex: "
+        + stateUpdate.getVertexName() + " in vertex: "
+        + getContext().getVertexName());
+
+    SourceVertexInfo vInfo = srcVertexInfo.get(stateUpdate.getVertexName());
+    if(vInfo != null) {
+      Preconditions.checkState(vInfo.vertexIsConfigured == false);
+      vInfo.vertexIsConfigured = true;
+      vInfo.numTasks = getContext().getVertexNumTasks(
+          stateUpdate.getVertexName());
+      totalNumSourceTasks += vInfo.numTasks;
+      LOG.info("Received configured notification : {}" + " for vertex: {} in" +
+          " vertex: {}" + " numjourceTasks: {}",
+        stateUpdate.getVertexState(), stateUpdate.getVertexName(),
+        getContext().getVertexName(), totalNumSourceTasks);
+      processPendingTasks();
+    }
+  }
+
+  @Override
+  public void onSourceTaskCompleted(TaskAttemptIdentifier attempt) {
+    String srcVertexName = attempt.getTaskIdentifier().getVertexIdentifier()
+        .getName();
+    int srcTaskId = attempt.getTaskIdentifier().getIdentifier();
+    SourceVertexInfo srcInfo = srcVertexInfo.get(srcVertexName);
+    if (srcInfo.vertexIsConfigured) {
+      Preconditions.checkState(srcTaskId < srcInfo.numTasks,
+          "Received completion for srcTaskId " + srcTaskId + " but Vertex: "
+          + srcVertexName + " has only " + srcInfo.numTasks + " tasks");
+    }
+    // handle duplicate events and count task completions from
+    // all source vertices
+    BitSet completedSourceTasks = srcInfo.finishedTaskSet;
+    // duplicate notifications tracking
+    if (!completedSourceTasks.get(srcTaskId)) {
+      completedSourceTasks.set(srcTaskId);
+      // source task has completed
+      numSourceTasksCompleted++;
+    }
+    processPendingTasks();
+  }
+
+  @Override
+  public void initialize() {
+    UserPayload userPayload = getContext().getUserPayload();
+    if (userPayload == null || userPayload.getPayload() == null ||
+        userPayload.getPayload().limit() == 0) {
+      throw new RuntimeException("Could not initialize RootInputVertexManager"
+          + " from provided user payload");
+    }
+    try {
+      conf = TezUtils.createConfFromUserPayload(userPayload);
+    } catch (IOException e) {
+      throw new RuntimeException("Could not initialize RootInputVertexManager"
+        + " from provided user payload", e);
+    }
+    slowStartEnabled = conf.getBoolean(
+      TEZ_ROOT_INPUT_VERTEX_MANAGER_ENABLE_SLOW_START,
+      TEZ_ROOT_INPUT_VERTEX_MANAGER_ENABLE_SLOW_START_DEFAULT);
+
+    if(slowStartEnabled) {
+      slowStartMinFraction = conf.getFloat(
+        TEZ_ROOT_INPUT_VERTEX_MANAGER_MIN_SRC_FRACTION,
+        TEZ_ROOT_INPUT_VERTEX_MANAGER_MIN_SRC_FRACTION_DEFAULT);
+      slowStartMaxFraction = conf.getFloat(
+        TEZ_ROOT_INPUT_VERTEX_MANAGER_MAX_SRC_FRACTION,
+        Math.max(slowStartMinFraction,
+            TEZ_ROOT_INPUT_VERTEX_MANAGER_MAX_SRC_FRACTION_DEFAULT));
+    } else {
+      slowStartMinFraction = 0;
+      slowStartMaxFraction = 0;
+    }
+    if (slowStartMinFraction < 0 || slowStartMaxFraction > 1
+        || slowStartMaxFraction < slowStartMinFraction) {
+      throw new IllegalArgumentException(
+          "Invalid values for slowStartMinFraction"
+          + "/slowStartMaxFraction. Min "
+          + "cannot be < 0, max cannot be > 1, and max cannot be < min."
+          + ", configuredMin=" + slowStartMinFraction
+          + ", configuredMax=" + slowStartMaxFraction);
+    }
+
+
+    updatePendingTasks();
+  }
+
+  @Override
+  public void onVertexManagerEventReceived(VertexManagerEvent vmEvent) {
+  }
 
   @Override
   public void onRootVertexInitialized(String inputName, InputDescriptor inputDescriptor,
@@ -97,4 +321,228 @@ public class RootInputVertexManager extends ImmediateStartVertexManager {
     }
     getContext().addRootInputEvents(inputName, riEvents);
   }
+
+  private boolean canScheduleTasks() {
+    // check for source vertices completely configured
+    for (Map.Entry<String, SourceVertexInfo> entry : srcVertexInfo.entrySet()) {
+      if (!entry.getValue().vertexIsConfigured) {
+        // vertex not configured
+        if (LOG.isDebugEnabled()) {
+          LOG.debug("Waiting for vertex: " + entry.getKey() + " in vertex: "
+              + getContext().getVertexName());
+        }
+        return false;
+      }
+    }
+
+    sourceVerticesScheduled = true;
+    return sourceVerticesScheduled;
+  }
+
+  private boolean preconditionsSatisfied() {
+    if (!onVertexStartedDone.get()) {
+      // vertex not started yet
+      return false;
+    }
+
+    if (!sourceVerticesScheduled && !canScheduleTasks()) {
+      if (LOG.isDebugEnabled()) {
+        LOG.debug("Defer scheduling tasks for vertex: {} as one task needs " +
+            "to be completed per source vertex", getContext().getVertexName());
+      }
+      return false;
+    }
+    return true;
+  }
+
+  void updatePendingTasks() {
+    int tasks = getContext().getVertexNumTasks(getContext().getVertexName());
+    if (tasks == pendingTasks.size() || tasks <= 0) {
+      return;
+    }
+    pendingTasks.clear();
+    for (int i = 0; i < tasks; ++i) {
+      pendingTasks.add(new PendingTaskInfo(i));
+    }
+    totalTasksToSchedule = pendingTasks.size();
+  }
+
+  private void processPendingTasks() {
+    if (!preconditionsSatisfied()) {
+      return;
+    }
+    schedulePendingTasks();
+  }
+
+  private void schedulePendingTasks() {
+    List<VertexManagerPluginContext.ScheduleTaskRequest> scheduledTasks =
+        getTasksToSchedule();
+    if (scheduledTasks != null && scheduledTasks.size() > 0) {
+      getContext().scheduleTasks(scheduledTasks);
+    }
+  }
+
+  float getMinSourceVertexCompletedTaskFraction() {
+    float minSourceVertexCompletedTaskFraction = 1f;
+
+    if (numSourceTasksCompleted != totalNumSourceTasks) {
+      for (Map.Entry<String, SourceVertexInfo> vInfo :
+          srcVertexInfo.entrySet()) {
+        SourceVertexInfo srcInfo = vInfo.getValue();
+        // canScheduleTasks check has already verified all
+        // sources are configured
+        Preconditions.checkState(srcInfo.vertexIsConfigured,
+            "Vertex: " + vInfo.getKey());
+        if (srcInfo.numTasks > 0) {
+          int numCompletedTasks = srcInfo.getNumCompletedTasks();
+          float completedFraction =
+              (float) numCompletedTasks / srcInfo.numTasks;
+          if (minSourceVertexCompletedTaskFraction > completedFraction) {
+            minSourceVertexCompletedTaskFraction = completedFraction;
+          }
+        }
+      }
+    }
+    return minSourceVertexCompletedTaskFraction;
+  }
+
+  List<VertexManagerPluginContext.ScheduleTaskRequest> getTasksToSchedule() {
+    float minSourceVertexCompletedTaskFraction =
+        getMinSourceVertexCompletedTaskFraction();
+    int numTasksToSchedule = getNumOfTasksToScheduleAndLog(
+        minSourceVertexCompletedTaskFraction);
+    if (numTasksToSchedule > 0) {
+      List<VertexManagerPluginContext.ScheduleTaskRequest> tasksToSchedule =
+          Lists.newArrayListWithCapacity(numTasksToSchedule);
+
+      while (!pendingTasks.isEmpty() && numTasksToSchedule > 0) {
+        numTasksToSchedule--;
+        Integer taskIndex = pendingTasks.get(0).getIndex();
+        tasksToSchedule.add(VertexManagerPluginContext.ScheduleTaskRequest
+            .create(taskIndex, null));
+        pendingTasks.remove(0);
+      }
+      return tasksToSchedule;
+    }
+    return null;
+  }
+
+  int getNumOfTasksToScheduleAndLog(float minFraction) {
+    int numTasksToSchedule = getNumOfTasksToSchedule(minFraction);
+    if (numTasksToSchedule > 0) {
+      // numTasksToSchedule can be -ve if minFraction
+      // is less than slowStartMinSrcCompletionFraction.
+      LOG.info("Scheduling {} tasks for vertex: {} with totalTasks: {}. " +
+          "{} source tasks completed out of {}. " +
+          "MinSourceTaskCompletedFraction: {} min: {} max: {}",
+          numTasksToSchedule, getContext().getVertexName(),
+          totalTasksToSchedule, numSourceTasksCompleted,
+          totalNumSourceTasks, minFraction,
+          slowStartMinFraction, slowStartMaxFraction);
+    }
+    return numTasksToSchedule;
+  }
+
+  int getNumOfTasksToSchedule(float minSourceVertexCompletedTaskFraction) {
+    int numPendingTasks = pendingTasks.size();
+    if (numSourceTasksCompleted == totalNumSourceTasks) {
+      LOG.info("All source tasks completed. Ramping up {} remaining tasks" +
+          " for vertex: {}", numPendingTasks, getContext().getVertexName());
+      return numPendingTasks;
+    }
+
+    // start scheduling when source tasks completed fraction is more than min.
+    // linearly increase the number of scheduled tasks such that all tasks are
+    // scheduled when source tasks completed fraction reaches max
+    float tasksFractionToSchedule = 1;
+    float percentRange =
+        slowStartMaxFraction - slowStartMinFraction;
+    if (percentRange > 0) {
+      tasksFractionToSchedule =
+          (minSourceVertexCompletedTaskFraction -
+              slowStartMinFraction) / percentRange;
+    } else {
+      // min and max are equal. schedule 100% on reaching min
+      if(minSourceVertexCompletedTaskFraction <
+          slowStartMinFraction) {
+        tasksFractionToSchedule = 0;
+      }
+    }
+
+    tasksFractionToSchedule =
+        Math.max(0, Math.min(1, tasksFractionToSchedule));
+
+    // round up to avoid the corner case that single task cannot be scheduled
+    // until src completed fraction reach max
+    return ((int)(Math.ceil(tasksFractionToSchedule * totalTasksToSchedule)) -
+        (totalTasksToSchedule - numPendingTasks));
+  }
+
+  /**
+   * Create a {@link VertexManagerPluginDescriptor} builder that can be used to
+   * configure the plugin.
+   *
+   * @param conf
+   *          {@link Configuration} May be modified in place. May be null if the
+   *          configuration parameters are to be set only via code. If
+   *          configuration values may be changed at runtime via a config file
+   *          then pass in a {@link Configuration} that is initialized from a
+   *          config file. The parameters that are not overridden in code will
+   *          be derived from the Configuration object.
+   * @return {@link RootInputVertexManagerConfigBuilder}
+   */
+  public static RootInputVertexManagerConfigBuilder createConfigBuilder(
+      @Nullable Configuration conf) {
+    return new RootInputVertexManagerConfigBuilder(conf);
+  }
+
+  /**
+   * Helper class to configure RootInputVertexManager
+   */
+  public static final class RootInputVertexManagerConfigBuilder {
+    private final Configuration conf;
+
+    private RootInputVertexManagerConfigBuilder(@Nullable Configuration conf) {
+      if (conf == null) {
+        this.conf = new Configuration(false);
+      } else {
+        this.conf = conf;
+      }
+    }
+
+    public RootInputVertexManagerConfigBuilder setSlowStart(
+        boolean enabled) {
+      conf.setBoolean(TEZ_ROOT_INPUT_VERTEX_MANAGER_ENABLE_SLOW_START,
+          enabled);
+      return this;
+    }
+
+    public RootInputVertexManagerConfigBuilder
+        setSlowStartMinSrcCompletionFraction(float minFraction) {
+      conf.setFloat(TEZ_ROOT_INPUT_VERTEX_MANAGER_MIN_SRC_FRACTION,
+          minFraction);
+      return this;
+    }
+
+    public RootInputVertexManagerConfigBuilder
+        setSlowStartMaxSrcCompletionFraction(float maxFraction) {
+      conf.setFloat(TEZ_ROOT_INPUT_VERTEX_MANAGER_MAX_SRC_FRACTION,
+          maxFraction);
+      return this;
+    }
+
+    public VertexManagerPluginDescriptor build() {
+      VertexManagerPluginDescriptor desc =
+          VertexManagerPluginDescriptor.create(
+              RootInputVertexManager.class.getName());
+
+      try {
+        return desc.setUserPayload(TezUtils
+            .createUserPayloadFromConf(this.conf));
+      } catch (IOException e) {
+        throw new TezUncheckedException(e);
+      }
+    }
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/tez/blob/b32b66f5/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java
index 30d65c4..4263094 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java
@@ -2723,8 +2723,8 @@ public class VertexImpl implements org.apache.tez.dag.app.dag.Vertex, EventHandl
       if (inputsWithInitializers != null) {
         LOG.info("Setting vertexManager to RootInputVertexManager for "
             + logIdentifier);
-        vertexManager = new VertexManager(
-            VertexManagerPluginDescriptor.create(RootInputVertexManager.class.getName()),
+        vertexManager = new VertexManager(RootInputVertexManager
+            .createConfigBuilder(vertexConf).build(),
             dagUgi, this, appContext, stateChangeNotifier);
       } else if (hasOneToOne && !hasCustom) {
         LOG.info("Setting vertexManager to InputReadyVertexManager for "

http://git-wip-us.apache.org/repos/asf/tez/blob/b32b66f5/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestRootInputVertexManager.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestRootInputVertexManager.java b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestRootInputVertexManager.java
index 344a1db..50bac69 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestRootInputVertexManager.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestRootInputVertexManager.java
@@ -18,30 +18,60 @@
 
 package org.apache.tez.dag.app.dag.impl;
 
+import static org.apache.tez.dag.app.dag.impl.RootInputVertexManager.TEZ_ROOT_INPUT_VERTEX_MANAGER_ENABLE_SLOW_START;
+import static org.apache.tez.dag.app.dag.impl.RootInputVertexManager.TEZ_ROOT_INPUT_VERTEX_MANAGER_MAX_SRC_FRACTION;
+import static org.apache.tez.dag.app.dag.impl.RootInputVertexManager.TEZ_ROOT_INPUT_VERTEX_MANAGER_MIN_SRC_FRACTION;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyList;
 import static org.mockito.Matchers.eq;
+import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
+import java.io.IOException;
+import java.util.HashMap;
 import java.util.LinkedList;
 import java.util.List;
 
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.tez.common.TezUtils;
+import org.apache.tez.dag.api.EdgeProperty;
 import org.apache.tez.dag.api.InputDescriptor;
+import org.apache.tez.dag.api.OutputDescriptor;
+import org.apache.tez.dag.api.TezConfiguration;
+import org.apache.tez.dag.api.UserPayload;
 import org.apache.tez.dag.api.VertexManagerPluginContext;
+import org.apache.tez.dag.api.event.VertexState;
+import org.apache.tez.dag.api.event.VertexStateUpdate;
 import org.apache.tez.runtime.api.Event;
+import org.apache.tez.runtime.api.TaskAttemptIdentifier;
+import org.apache.tez.runtime.api.TaskIdentifier;
+import org.apache.tez.runtime.api.VertexIdentifier;
+import org.apache.tez.runtime.api.VertexStatistics;
 import org.apache.tez.runtime.api.events.InputConfigureVertexTasksEvent;
 import org.apache.tez.runtime.api.events.InputDataInformationEvent;
+import org.junit.Assert;
 import org.junit.Test;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
 
 public class TestRootInputVertexManager {
 
+  List<TaskAttemptIdentifier> emptyCompletions = null;
+
   @Test(timeout = 5000)
-  public void testEventsFromMultipleInputs() {
+  public void testEventsFromMultipleInputs() throws IOException {
 
     VertexManagerPluginContext context = mock(VertexManagerPluginContext.class);
+    TezConfiguration conf = new TezConfiguration();
+    UserPayload vertexPayload = TezUtils.createUserPayloadFromConf(conf);
     doReturn("vertex1").when(context).getVertexName();
     doReturn(1).when(context).getVertexNumTasks(eq("vertex1"));
+    doReturn(vertexPayload).when(context).getUserPayload();
 
     RootInputVertexManager rootInputVertexManager = new RootInputVertexManager(context);
     rootInputVertexManager.initialize();
@@ -70,11 +100,14 @@ public class TestRootInputVertexManager {
   }
 
   @Test(timeout = 5000)
-  public void testConfigureFromMultipleInputs() {
+  public void testConfigureFromMultipleInputs() throws IOException {
 
     VertexManagerPluginContext context = mock(VertexManagerPluginContext.class);
+    TezConfiguration conf = new TezConfiguration();
+    UserPayload vertexPayload = TezUtils.createUserPayloadFromConf(conf);
     doReturn("vertex1").when(context).getVertexName();
     doReturn(-1).when(context).getVertexNumTasks(eq("vertex1"));
+    doReturn(vertexPayload).when(context).getUserPayload();
 
     RootInputVertexManager rootInputVertexManager = new RootInputVertexManager(context);
     rootInputVertexManager.initialize();
@@ -102,4 +135,426 @@ public class TestRootInputVertexManager {
     }
   }
 
+  @Test(timeout = 5000)
+  public void testRootInputVertexManagerSlowStart() { // slow start vs. source progress
+    Configuration conf = new Configuration();
+    RootInputVertexManager manager = null;
+    HashMap<String, EdgeProperty> mockInputVertices =
+        new HashMap<String, EdgeProperty>();
+    String mockSrcVertexId1 = "Vertex1";
+    EdgeProperty eProp1 = EdgeProperty.create(
+        EdgeProperty.DataMovementType.BROADCAST,
+        EdgeProperty.DataSourceType.PERSISTED,
+        EdgeProperty.SchedulingType.SEQUENTIAL,
+        OutputDescriptor.create("out"),
+        InputDescriptor.create("in"));
+    String mockSrcVertexId2 = "Vertex2";
+    EdgeProperty eProp2 = EdgeProperty.create(
+        EdgeProperty.DataMovementType.BROADCAST,
+        EdgeProperty.DataSourceType.PERSISTED,
+        EdgeProperty.SchedulingType.SEQUENTIAL,
+        OutputDescriptor.create("out"),
+        InputDescriptor.create("in"));
+
+    String mockManagedVertexId = "Vertex3";
+
+    VertexManagerPluginContext mockContext =
+        mock(VertexManagerPluginContext.class);
+    when(mockContext.getVertexStatistics(any(String.class)))
+        .thenReturn(mock(VertexStatistics.class));
+    when(mockContext.getInputVertexEdgeProperties())
+        .thenReturn(mockInputVertices);
+    when(mockContext.getVertexName()).thenReturn(mockManagedVertexId);
+    when(mockContext.getVertexNumTasks(mockManagedVertexId)).thenReturn(3);
+
+    mockInputVertices.put(mockSrcVertexId1, eProp1);
+    mockInputVertices.put(mockSrcVertexId2, eProp2);
+
+    // check initialization
+    manager = createRootInputVertexManager(conf, mockContext, 0.1f, 0.1f);
+
+    final List<Integer> scheduledTasks = Lists.newLinkedList();
+    doAnswer(new ScheduledTasksAnswer(scheduledTasks)).when(
+        mockContext).scheduleTasks(anyList());
+
+    // source vertices have 0 tasks. immediate start of all managed tasks
+    when(mockContext.getVertexNumTasks(mockSrcVertexId1)).thenReturn(0);
+    when(mockContext.getVertexNumTasks(mockSrcVertexId2)).thenReturn(0);
+    manager.onVertexStarted(emptyCompletions);
+    manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId1,
+        VertexState.CONFIGURED));
+    manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId2,
+        VertexState.CONFIGURED));
+    Assert.assertEquals(0, manager.pendingTasks.size());
+    Assert.assertEquals(3, scheduledTasks.size()); // all tasks scheduled
+
+    when(mockContext.getVertexNumTasks(mockSrcVertexId1)).thenReturn(2);
+    when(mockContext.getVertexNumTasks(mockSrcVertexId2)).thenReturn(2);
+
+    try {
+      // source vertices have some tasks. min < 0.
+      manager = createRootInputVertexManager(conf, mockContext, -0.1f, 0.0f);
+      Assert.fail("expected IllegalArgumentException for min < 0");
+    } catch (IllegalArgumentException e) {
+      Assert.assertTrue(e.getMessage().contains(
+          "Invalid values for slowStartMinFraction"));
+    }
+
+    try {
+      // source vertices have some tasks. max > 1.
+      manager = createRootInputVertexManager(conf, mockContext, 0.0f, 95.0f);
+      Assert.fail("expected IllegalArgumentException for max > 1");
+    } catch (IllegalArgumentException e) {
+      Assert.assertTrue(e.getMessage().contains(
+          "Invalid values for slowStartMinFraction"));
+    }
+
+    try {
+      // source vertices have some tasks. min > max
+      manager = createRootInputVertexManager(conf, mockContext, 0.5f, 0.3f);
+      Assert.fail("expected IllegalArgumentException for min > max");
+    } catch (IllegalArgumentException e) {
+      Assert.assertTrue(e.getMessage().contains(
+          "Invalid values for slowStartMinFraction"));
+    }
+
+    // source vertices have some tasks. min > default and max undefined
+    int numTasks = 20;
+    when(mockContext.getVertexNumTasks(mockSrcVertexId1)).thenReturn(numTasks);
+    when(mockContext.getVertexNumTasks(mockSrcVertexId2)).thenReturn(numTasks);
+    scheduledTasks.clear();
+
+    manager = createRootInputVertexManager(conf, mockContext, 0.975f, null);
+    manager.onVertexStarted(emptyCompletions);
+    manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId1,
+        VertexState.CONFIGURED));
+    manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId2,
+        VertexState.CONFIGURED));
+
+    Assert.assertEquals(3, manager.pendingTasks.size());
+    Assert.assertEquals(numTasks*2, manager.totalNumSourceTasks);
+    Assert.assertEquals(0, manager.numSourceTasksCompleted);
+    float completedTasksThreshold = 0.975f * numTasks;
+    // Complete source tasks up to, but not past, the threshold
+    for (String mockSrcVertex : new String[] { mockSrcVertexId1,
+        mockSrcVertexId2 }) {
+      for (int i = 0; i < mockContext.getVertexNumTasks(mockSrcVertex); ++i) {
+        // complete 0th tasks outside the loop
+        manager.onSourceTaskCompleted(createTaskAttemptIdentifier(
+            mockSrcVertex, i+1));
+        if ((i + 2) >= completedTasksThreshold) {
+          // stop before completing more than min/max source tasks
+          break;
+        }
+      }
+    }
+    // Since we haven't exceeded the threshold, all tasks are still pending
+    Assert.assertEquals(manager.totalTasksToSchedule,
+        manager.pendingTasks.size());
+    Assert.assertEquals(0, scheduledTasks.size()); // no tasks scheduled
+
+    // Cross the min/max threshold to schedule all tasks
+    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(
+        mockSrcVertexId1, 0));
+    Assert.assertEquals(3, manager.pendingTasks.size());
+    Assert.assertEquals(0, scheduledTasks.size());
+    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(
+        mockSrcVertexId2, 0));
+    Assert.assertEquals(0, manager.pendingTasks.size());
+    // all tasks scheduled
+    Assert.assertEquals(manager.totalTasksToSchedule, scheduledTasks.size());
+
+    // reset vertices for next test
+    when(mockContext.getVertexNumTasks(mockSrcVertexId1)).thenReturn(2);
+    when(mockContext.getVertexNumTasks(mockSrcVertexId2)).thenReturn(2);
+
+    // source vertices have some tasks. min = max = 0
+    manager = createRootInputVertexManager(conf, mockContext, 0.0f, 0.0f);
+    manager.onVertexStarted(emptyCompletions);
+    Assert.assertEquals(3, manager.totalTasksToSchedule);
+    Assert.assertEquals(0, manager.numSourceTasksCompleted);
+    // all source vertices need to be configured
+    manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId1,
+        VertexState.CONFIGURED));
+    manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId2,
+        VertexState.CONFIGURED));
+    Assert.assertEquals(4, manager.totalNumSourceTasks);
+    Assert.assertEquals(0, manager.pendingTasks.size());
+    Assert.assertEquals(3, scheduledTasks.size()); // all tasks scheduled
+
+    // min = max = 0.25
+    manager = createRootInputVertexManager(conf, mockContext, 0.25f, 0.25f);
+    manager.onVertexStarted(emptyCompletions);
+    manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId1,
+        VertexState.CONFIGURED));
+    manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId2,
+        VertexState.CONFIGURED));
+    Assert.assertEquals(3, manager.pendingTasks.size()); // no tasks scheduled
+    Assert.assertEquals(4, manager.totalNumSourceTasks);
+    // no source task has completed yet, so nothing is scheduled
+    Assert.assertEquals(3, manager.pendingTasks.size()); // no tasks scheduled
+    Assert.assertEquals(4, manager.totalNumSourceTasks);
+    Assert.assertEquals(0, manager.numSourceTasksCompleted);
+    // task completion on only 1 source vertex does nothing
+    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(
+        mockSrcVertexId2, 0));
+    Assert.assertEquals(3, manager.pendingTasks.size()); // no tasks scheduled
+    Assert.assertEquals(4, manager.totalNumSourceTasks);
+    Assert.assertEquals(1, manager.numSourceTasksCompleted);
+    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(
+        mockSrcVertexId1, 0));
+    Assert.assertEquals(0, manager.pendingTasks.size());
+    Assert.assertEquals(3, scheduledTasks.size()); // all tasks scheduled
+    Assert.assertEquals(2, manager.numSourceTasksCompleted);
+
+    // min = max = 1.0 (absolute max)
+    manager = createRootInputVertexManager(conf, mockContext, 1.0f, 1.0f);
+    manager.onVertexStarted(emptyCompletions);
+    manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId1,
+        VertexState.CONFIGURED));
+    manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId2,
+        VertexState.CONFIGURED));
+    Assert.assertEquals(3, manager.pendingTasks.size()); // no tasks scheduled
+    Assert.assertEquals(4, manager.totalNumSourceTasks);
+    // no source task has completed yet, so nothing is scheduled
+    Assert.assertEquals(3, manager.pendingTasks.size()); // no tasks scheduled
+    Assert.assertEquals(4, manager.totalNumSourceTasks);
+    Assert.assertEquals(0, manager.numSourceTasksCompleted);
+    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(
+        mockSrcVertexId1, 0));
+    Assert.assertEquals(3, manager.pendingTasks.size());
+    Assert.assertEquals(1, manager.numSourceTasksCompleted);
+    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(
+        mockSrcVertexId1, 1));
+    Assert.assertEquals(3, manager.pendingTasks.size());
+    Assert.assertEquals(2, manager.numSourceTasksCompleted);
+    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(
+        mockSrcVertexId2, 0));
+    Assert.assertEquals(3, manager.pendingTasks.size());
+    Assert.assertEquals(3, manager.numSourceTasksCompleted);
+    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(
+        mockSrcVertexId2, 1));
+    Assert.assertEquals(0, manager.pendingTasks.size());
+    Assert.assertEquals(3, scheduledTasks.size()); // all tasks scheduled
+    Assert.assertEquals(4, manager.numSourceTasksCompleted);
+
+    // min = max = 1.0 again, with a fresh manager
+    manager = createRootInputVertexManager(conf, mockContext, 1.0f, 1.0f);
+    manager.onVertexStarted(emptyCompletions);
+    manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId1,
+        VertexState.CONFIGURED));
+    manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId2,
+        VertexState.CONFIGURED));
+    Assert.assertEquals(3, manager.pendingTasks.size()); // no tasks scheduled
+    Assert.assertEquals(4, manager.totalNumSourceTasks);
+    // no source task has completed yet, so nothing is scheduled
+    Assert.assertEquals(3, manager.pendingTasks.size()); // no tasks scheduled
+    Assert.assertEquals(4, manager.totalNumSourceTasks);
+    Assert.assertEquals(0, manager.numSourceTasksCompleted);
+    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(
+        mockSrcVertexId1, 0));
+    Assert.assertEquals(3, manager.pendingTasks.size());
+    Assert.assertEquals(1, manager.numSourceTasksCompleted);
+    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(
+        mockSrcVertexId1, 1));
+    Assert.assertEquals(3, manager.pendingTasks.size());
+    Assert.assertEquals(2, manager.numSourceTasksCompleted);
+    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(
+        mockSrcVertexId2, 0));
+    Assert.assertEquals(3, manager.pendingTasks.size());
+    Assert.assertEquals(3, manager.numSourceTasksCompleted);
+    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(
+        mockSrcVertexId2, 1));
+    Assert.assertEquals(0, manager.pendingTasks.size());
+    Assert.assertEquals(3, scheduledTasks.size()); // all tasks scheduled
+    Assert.assertEquals(4, manager.numSourceTasksCompleted);
+
+    // reset vertices for next test
+    when(mockContext.getVertexNumTasks(mockSrcVertexId1)).thenReturn(4);
+    when(mockContext.getVertexNumTasks(mockSrcVertexId2)).thenReturn(4);
+
+    // min, max > 0 and min < max
+    manager = createRootInputVertexManager(conf, mockContext, 0.25f, 0.75f);
+    manager.onVertexStarted(emptyCompletions);
+    manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId1,
+        VertexState.CONFIGURED));
+    manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId2,
+        VertexState.CONFIGURED));
+    Assert.assertEquals(3, manager.pendingTasks.size()); // no tasks scheduled
+    Assert.assertEquals(8, manager.totalNumSourceTasks);
+    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(
+        mockSrcVertexId1, 0));
+    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(
+        mockSrcVertexId2, 1));
+    Assert.assertEquals(3, manager.pendingTasks.size());
+    Assert.assertEquals(2, manager.numSourceTasksCompleted);
+    // completion of same task again should not get counted
+    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(
+        mockSrcVertexId2, 1));
+    Assert.assertEquals(3, manager.pendingTasks.size());
+    Assert.assertEquals(2, manager.numSourceTasksCompleted);
+    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(
+        mockSrcVertexId1, 1));
+    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(
+        mockSrcVertexId2, 0));
+    Assert.assertEquals(1, manager.pendingTasks.size());
+    Assert.assertEquals(2, scheduledTasks.size()); // 2 tasks scheduled
+    Assert.assertEquals(4, manager.numSourceTasksCompleted);
+    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(
+        mockSrcVertexId1, 2));
+    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(
+        mockSrcVertexId2, 2));
+    Assert.assertEquals(0, manager.pendingTasks.size());
+    Assert.assertEquals(1, scheduledTasks.size()); // 1 more task scheduled
+    Assert.assertEquals(6, manager.numSourceTasksCompleted);
+    scheduledTasks.clear();
+    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(
+        mockSrcVertexId1, 3)); // we are done. no action
+    Assert.assertEquals(0, manager.pendingTasks.size());
+    Assert.assertEquals(0, scheduledTasks.size()); // no task scheduled
+    Assert.assertEquals(7, manager.numSourceTasksCompleted);
+
+    // min > 0 and max = 1.0
+    manager = createRootInputVertexManager(conf, mockContext, 0.25f, 1.0f);
+    manager.onVertexStarted(emptyCompletions);
+    manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId1,
+        VertexState.CONFIGURED));
+    manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId2,
+        VertexState.CONFIGURED));
+    Assert.assertEquals(3, manager.pendingTasks.size()); // no tasks scheduled
+    Assert.assertEquals(8, manager.totalNumSourceTasks);
+    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(
+        mockSrcVertexId1, 0));
+    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(
+        mockSrcVertexId2, 1));
+    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(
+        mockSrcVertexId1, 1));
+    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(
+        mockSrcVertexId2, 0));
+    Assert.assertEquals(2, manager.pendingTasks.size());
+    Assert.assertEquals(1, scheduledTasks.size()); // 1 task scheduled
+    Assert.assertEquals(4, manager.numSourceTasksCompleted);
+    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(
+        mockSrcVertexId1, 2));
+    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(
+        mockSrcVertexId2, 2));
+    Assert.assertEquals(1, manager.pendingTasks.size());
+    Assert.assertEquals(1, scheduledTasks.size()); // 1 more task scheduled
+    Assert.assertEquals(6, manager.numSourceTasksCompleted);
+    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(
+        mockSrcVertexId1, 3));
+    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(
+        mockSrcVertexId2, 3));
+    Assert.assertEquals(0, manager.pendingTasks.size());
+    Assert.assertEquals(1, scheduledTasks.size()); // final task scheduled
+    Assert.assertEquals(8, manager.numSourceTasksCompleted);
+
+    // if there is a single task to schedule, it should be scheduled when the src
+    // completed fraction is more than the min slow start fraction
+    scheduledTasks.clear();
+    when(mockContext.getVertexNumTasks(mockManagedVertexId)).thenReturn(1);
+    manager = createRootInputVertexManager(conf, mockContext, 0.25f, 0.75f);
+    manager.onVertexStarted(emptyCompletions);
+    manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId1,
+        VertexState.CONFIGURED));
+    manager.onVertexStateUpdated(new VertexStateUpdate(mockSrcVertexId2,
+        VertexState.CONFIGURED));
+    Assert.assertEquals(1, manager.pendingTasks.size()); // no tasks scheduled
+    Assert.assertEquals(8, manager.totalNumSourceTasks);
+    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(
+        mockSrcVertexId1, 0));
+    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(
+        mockSrcVertexId2, 1));
+    Assert.assertEquals(1, manager.pendingTasks.size());
+    Assert.assertEquals(0, scheduledTasks.size()); // no task scheduled
+    Assert.assertEquals(2, manager.numSourceTasksCompleted);
+    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(
+        mockSrcVertexId1, 1));
+    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(
+        mockSrcVertexId2, 0));
+    Assert.assertEquals(0, manager.pendingTasks.size());
+    Assert.assertEquals(1, scheduledTasks.size()); // 1 task scheduled
+    Assert.assertEquals(4, manager.numSourceTasksCompleted);
+    scheduledTasks.clear();
+    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(
+        mockSrcVertexId1, 2));
+    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(
+        mockSrcVertexId2, 2));
+    Assert.assertEquals(0, manager.pendingTasks.size());
+    Assert.assertEquals(0, scheduledTasks.size()); // no task scheduled
+    Assert.assertEquals(6, manager.numSourceTasksCompleted);
+    scheduledTasks.clear();
+    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(
+        mockSrcVertexId1, 3)); // we are done. no action
+    Assert.assertEquals(0, manager.pendingTasks.size());
+    Assert.assertEquals(0, scheduledTasks.size()); // no task scheduled
+    Assert.assertEquals(7, manager.numSourceTasksCompleted);
+  }
+
+
+  /**
+   * Initializes a RootInputVertexManager from {@code conf} plus the given
+   * slow-start fractions; a null fraction unsets its key. Mutates {@code conf}.
+   */
+  static RootInputVertexManager createRootInputVertexManager(
+      Configuration conf, VertexManagerPluginContext context, Float min,
+      Float max) {
+    if (min != null) {
+      conf.setFloat(TEZ_ROOT_INPUT_VERTEX_MANAGER_MIN_SRC_FRACTION, min);
+    } else {
+      conf.unset(TEZ_ROOT_INPUT_VERTEX_MANAGER_MIN_SRC_FRACTION);
+    }
+    if (max != null) {
+      conf.setFloat(TEZ_ROOT_INPUT_VERTEX_MANAGER_MAX_SRC_FRACTION, max);
+    } else {
+      conf.unset(TEZ_ROOT_INPUT_VERTEX_MANAGER_MAX_SRC_FRACTION);
+    }
+    if (min != null || max != null) {
+      conf.setBoolean(TEZ_ROOT_INPUT_VERTEX_MANAGER_ENABLE_SLOW_START, true);
+    } else { // don't let a flag set by an earlier call on this conf leak through
+      conf.unset(TEZ_ROOT_INPUT_VERTEX_MANAGER_ENABLE_SLOW_START);
+    }
+    UserPayload payload;
+    try {
+      payload = TezUtils.createUserPayloadFromConf(conf);
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+    when(context.getUserPayload()).thenReturn(payload);
+    RootInputVertexManager manager = new RootInputVertexManager(context);
+    manager.initialize();
+    return manager;
+  }
+
+  /** Mockito answer recording the task indices of the most recent scheduleTasks call only. */
+  protected static class ScheduledTasksAnswer implements Answer<Object> {
+    private final List<Integer> scheduledTasks;
+    public ScheduledTasksAnswer(List<Integer> scheduledTasks) {
+      this.scheduledTasks = scheduledTasks;
+    }
+    @Override
+    public Object answer(InvocationOnMock invocation) throws IOException {
+      scheduledTasks.clear(); // each call replaces the previously recorded batch
+      @SuppressWarnings("unchecked") List<VertexManagerPluginContext.ScheduleTaskRequest> tasks =
+          (List<VertexManagerPluginContext.ScheduleTaskRequest>) invocation.getArguments()[0];
+      for (VertexManagerPluginContext.ScheduleTaskRequest task : tasks) {
+        scheduledTasks.add(task.getTaskIndex());
+      }
+      return null;
+    }
+  }
+
+  public static TaskAttemptIdentifier createTaskAttemptIdentifier(String vName,
+      int tId) {
+    VertexIdentifier mockVertex = mock(VertexIdentifier.class);
+    when(mockVertex.getName()).thenReturn(vName);
+    TaskIdentifier mockTask = mock(TaskIdentifier.class);
+    when(mockTask.getIdentifier()).thenReturn(tId);
+    when(mockTask.getVertexIdentifier()).thenReturn(mockVertex);
+    TaskAttemptIdentifier mockAttempt = mock(TaskAttemptIdentifier.class);
+    when(mockAttempt.getIdentifier()).thenReturn(0);
+    when(mockAttempt.getTaskIdentifier()).thenReturn(mockTask);
+    return mockAttempt;
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/tez/blob/b32b66f5/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexImpl.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexImpl.java b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexImpl.java
index b3dd60a..6eca322 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexImpl.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexImpl.java
@@ -6830,7 +6830,8 @@ public class TestVertexImpl {
   }
 
   @InterfaceAudience.Private
-  public static class VertexManagerWithException extends RootInputVertexManager{
+  public static class VertexManagerWithException extends
+      ImmediateStartVertexManager{
 
     public static enum VMExceptionLocation {
       NoExceptionDoReconfigure,

http://git-wip-us.apache.org/repos/asf/tez/blob/b32b66f5/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexManager.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexManager.java b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexManager.java
index 6bec26e..3d9f271 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexManager.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexManager.java
@@ -42,6 +42,7 @@ import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.Callable;
 
+import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.security.UserGroupInformation;
 import org.apache.hadoop.yarn.event.EventHandler;
 import org.apache.tez.dag.api.InputDescriptor;
@@ -137,11 +138,11 @@ public class TestVertexManager {
   
   @Test(timeout = 5000)
   public void testOnRootVertexInitialized() throws Exception {
+    Configuration conf = new Configuration();
     VertexManager vm =
-        new VertexManager(
-            VertexManagerPluginDescriptor.create(RootInputVertexManager.class
-                .getName()), UserGroupInformation.getCurrentUser(), 
-                mockVertex, mockAppContext, mock(StateChangeNotifier.class));
+        new VertexManager(RootInputVertexManager.createConfigBuilder(conf)
+            .build(), UserGroupInformation.getCurrentUser(),
+            mockVertex, mockAppContext, mock(StateChangeNotifier.class));
     vm.initialize();
     InputDescriptor id1 = mock(InputDescriptor.class);
     List<Event> events1 = new LinkedList<Event>();

http://git-wip-us.apache.org/repos/asf/tez/blob/b32b66f5/tez-tests/src/test/java/org/apache/tez/test/TestExceptionPropagation.java
----------------------------------------------------------------------
diff --git a/tez-tests/src/test/java/org/apache/tez/test/TestExceptionPropagation.java b/tez-tests/src/test/java/org/apache/tez/test/TestExceptionPropagation.java
index fc1dab7..404e324 100644
--- a/tez-tests/src/test/java/org/apache/tez/test/TestExceptionPropagation.java
+++ b/tez-tests/src/test/java/org/apache/tez/test/TestExceptionPropagation.java
@@ -28,8 +28,11 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 
+import com.google.protobuf.ByteString;
 import org.apache.commons.lang.StringUtils;
+import org.apache.tez.dag.app.dag.impl.ImmediateStartVertexManager;
 import org.apache.tez.dag.app.dag.impl.OneToOneEdgeManagerOnDemand;
+import org.apache.tez.dag.app.dag.impl.RootInputVertexManager;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.apache.hadoop.conf.Configuration;
@@ -65,7 +68,6 @@ import org.apache.tez.dag.api.EdgeProperty.SchedulingType;
 import org.apache.tez.dag.api.client.DAGClient;
 import org.apache.tez.dag.api.client.DAGStatus;
 import org.apache.tez.dag.api.event.VertexStateUpdate;
-import org.apache.tez.dag.app.dag.impl.RootInputVertexManager;
 import org.apache.tez.dag.library.vertexmanager.InputReadyVertexManager;
 import org.apache.tez.runtime.api.AbstractLogicalIOProcessor;
 import org.apache.tez.runtime.api.AbstractLogicalInput;
@@ -344,7 +346,8 @@ public class TestExceptionPropagation {
         InputInitializerWithException.getIIDesc(payload);
     v1.addDataSource("input",
         DataSourceDescriptor.create(inputDesc, iiDesc, null));
-    v1.setVertexManagerPlugin(RootInputVertexManagerWithException.getVMDesc(payload));
+    v1.setVertexManagerPlugin(RootInputVertexManagerWithException
+        .getVMDesc(exLocation));
 
     Vertex v2 = 
         Vertex.create("v2", DoNothingProcessor.getProcDesc(), 1);
@@ -672,6 +675,8 @@ public class TestExceptionPropagation {
   public static class RootInputVertexManagerWithException extends RootInputVertexManager {
 
     private ExceptionLocation exLocation;
+    private static final String Test_ExceptionLocation =
+        "Test.ExceptionLocation"; // payload conf key holding the ExceptionLocation name to fail at
 
     public RootInputVertexManagerWithException(VertexManagerPluginContext context) {
       super(context);
@@ -680,9 +685,15 @@ public class TestExceptionPropagation {
     @Override
     public void initialize() {
       super.initialize();
-      this.exLocation =
-          ExceptionLocation.valueOf(new String(getContext().getUserPayload()
-              .deepCopyAsArray()));
+      Configuration conf;
+      try {
+        conf = TezUtils.createConfFromUserPayload(
+            getContext().getUserPayload());
+        this.exLocation = ExceptionLocation.valueOf(
+            conf.get(Test_ExceptionLocation));
+      } catch (IOException e) {
+        throw new TezUncheckedException(e);
+      }
       if (this.exLocation == ExceptionLocation.VM_INITIALIZE) {
         throw new RuntimeException(this.exLocation.name());
       }
@@ -705,9 +716,22 @@ public class TestExceptionPropagation {
       super.onVertexStarted(completions);
     }
 
-    public static VertexManagerPluginDescriptor getVMDesc(UserPayload payload) {
-      return VertexManagerPluginDescriptor.create(RootInputVertexManagerWithException.class.getName())
-              .setUserPayload(payload);
+    @Override
+    public void onSourceTaskCompleted(TaskAttemptIdentifier attempt) {
+      if (ExceptionLocation.VM_ON_SOURCETASK_COMPLETED.equals(exLocation)) {
+        throw new RuntimeException(exLocation.name()); // injected test failure
+      }
+      super.onSourceTaskCompleted(attempt); // normal path: delegate to the real manager
+    }
+
+    /** Builds a descriptor whose payload conf records the ExceptionLocation to fail at. */
+    public static VertexManagerPluginDescriptor getVMDesc(ExceptionLocation exLocation)
+        throws IOException {
+      Configuration payloadConf = new Configuration();
+      payloadConf.set(Test_ExceptionLocation, exLocation.name());
+      VertexManagerPluginDescriptor desc = VertexManagerPluginDescriptor.create(
+          RootInputVertexManagerWithException.class.getName());
+      return desc.setUserPayload(TezUtils.createUserPayloadFromConf(payloadConf));
+    }
   }
 


Mime
View raw message