tez-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From min...@apache.org
Subject [3/3] tez git commit: TEZ-3269. Provide basic fair routing and scheduling functionality via custom VertexManager and EdgeManager.
Date Sat, 12 Nov 2016 16:35:52 GMT
TEZ-3269. Provide basic fair routing and scheduling functionality via custom VertexManager and EdgeManager.


Project: http://git-wip-us.apache.org/repos/asf/tez/repo
Commit: http://git-wip-us.apache.org/repos/asf/tez/commit/2c4ef9fe
Tree: http://git-wip-us.apache.org/repos/asf/tez/tree/2c4ef9fe
Diff: http://git-wip-us.apache.org/repos/asf/tez/diff/2c4ef9fe

Branch: refs/heads/master
Commit: 2c4ef9fe58395aa8e835a4c50cb65bbb26428638
Parents: 0d59844
Author: Ming Ma <mingma@twitter.com>
Authored: Sat Nov 12 08:35:28 2016 -0800
Committer: Ming Ma <mingma@twitter.com>
Committed: Sat Nov 12 08:35:28 2016 -0800

----------------------------------------------------------------------
 CHANGES.txt                                     |    1 +
 tez-runtime-library/findbugs-exclude.xml        |   42 +
 tez-runtime-library/pom.xml                     |    1 +
 .../DestinationTaskInputsProperty.java          |   92 ++
 .../vertexmanager/FairEdgeConfiguration.java    |  111 ++
 .../vertexmanager/FairShuffleEdgeManager.java   |  154 ++
 .../vertexmanager/FairShuffleVertexManager.java |  631 ++++++++
 .../vertexmanager/ShuffleVertexManager.java     |   37 +-
 .../vertexmanager/ShuffleVertexManagerBase.java |  135 +-
 .../src/main/proto/FairShufflePayloads.proto    |   37 +
 .../TestFairShuffleVertexManager.java           |  347 +++++
 .../vertexmanager/TestShuffleVertexManager.java | 1424 ++----------------
 .../TestShuffleVertexManagerBase.java           | 1115 ++++++++++++++
 .../TestShuffleVertexManagerUtils.java          |  346 +++++
 14 files changed, 3080 insertions(+), 1393 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/tez/blob/2c4ef9fe/CHANGES.txt
----------------------------------------------------------------------
diff --git a/CHANGES.txt b/CHANGES.txt
index 8128c7b..0948862 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -7,6 +7,7 @@ INCOMPATIBLE CHANGES
 
 ALL CHANGES:
 
+  TEZ-3269. Provide basic fair routing and scheduling functionality via custom VertexManager and EdgeManager.
   TEZ-3534. Differentiate thread names on Fetchers, minor changes to shuffle shutdown code.
   TEZ-3491. Tez job can hang due to container priority inversion.
   TEZ-3533. ShuffleScheduler should shutdown threadpool on exit.

http://git-wip-us.apache.org/repos/asf/tez/blob/2c4ef9fe/tez-runtime-library/findbugs-exclude.xml
----------------------------------------------------------------------
diff --git a/tez-runtime-library/findbugs-exclude.xml b/tez-runtime-library/findbugs-exclude.xml
index d3b6245..da7a013 100644
--- a/tez-runtime-library/findbugs-exclude.xml
+++ b/tez-runtime-library/findbugs-exclude.xml
@@ -152,6 +152,7 @@
   <Match>
     <Class name="org.apache.tez.dag.library.vertexmanager.ShuffleVertexManagerBase"/>
     <Or>
+      <Field name="bipartiteSources"/>
       <Field name="numBipartiteSourceTasksCompleted"/>
       <Field name="totalNumBipartiteSourceTasks"/>
       <Field name="totalTasksToSchedule"/>
@@ -159,4 +160,45 @@
     <Bug pattern="IS2_INCONSISTENT_SYNC"/>
   </Match>
 
+  <Match>
+    <Class name="org.apache.tez.dag.library.vertexmanager.FairShuffleUserPayloads$FairShuffleEdgeManagerConfigPayloadProto"/>
+    <Field name="unknownFields"/>
+    <Bug pattern="SE_BAD_FIELD"/>
+  </Match>
+
+  <Match>
+    <Class name="org.apache.tez.dag.library.vertexmanager.FairShuffleUserPayloads$FairShuffleEdgeManagerDestinationTaskPropProto"/>
+    <Field name="unknownFields"/>
+    <Bug pattern="SE_BAD_FIELD"/>
+  </Match>
+
+  <Match>
+    <Class name="org.apache.tez.dag.library.vertexmanager.FairShuffleUserPayloads$RangeProto"/>
+    <Field name="unknownFields"/>
+    <Bug pattern="SE_BAD_FIELD"/>
+  </Match>
+
+  <Match>
+    <Class name="org.apache.tez.dag.library.vertexmanager.FairShuffleUserPayloads$FairShuffleEdgeManagerConfigPayloadProto"/>
+    <Field name="PARSER"/>
+    <Bug pattern="MS_SHOULD_BE_FINAL"/>
+  </Match>
+
+  <Match> 
+    <Class name="org.apache.tez.dag.library.vertexmanager.FairShuffleUserPayloads$FairShuffleEdgeManagerDestinationTaskPropProto"/>
+    <Field name="PARSER"/>
+    <Bug pattern="MS_SHOULD_BE_FINAL"/>
+  </Match>
+
+  <Match>   
+    <Class name="org.apache.tez.dag.library.vertexmanager.FairShuffleUserPayloads$RangeProto"/>
+    <Field name="PARSER"/>
+    <Bug pattern="MS_SHOULD_BE_FINAL"/>
+  </Match>
+
+  <Match>
+    <Class name="org.apache.tez.dag.library.vertexmanager.FairShuffleUserPayloads$RangeProto$Builder"/>
+    <Method name="maybeForceBuilderInitialization"/>
+    <Bug pattern="UCF_USELESS_CONTROL_FLOW"/>
+  </Match>
 </FindBugsFilter>

http://git-wip-us.apache.org/repos/asf/tez/blob/2c4ef9fe/tez-runtime-library/pom.xml
----------------------------------------------------------------------
diff --git a/tez-runtime-library/pom.xml b/tez-runtime-library/pom.xml
index b676933..2ccd65f 100644
--- a/tez-runtime-library/pom.xml
+++ b/tez-runtime-library/pom.xml
@@ -130,6 +130,7 @@
                 <includes>
                   <include>ShufflePayloads.proto</include>
                   <include>CartesianProductPayload.proto</include>
+                  <include>FairShufflePayloads.proto</include>
                 </includes>
               </source>
               <output>${project.build.directory}/generated-sources/java</output>

http://git-wip-us.apache.org/repos/asf/tez/blob/2c4ef9fe/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/DestinationTaskInputsProperty.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/DestinationTaskInputsProperty.java b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/DestinationTaskInputsProperty.java
new file mode 100644
index 0000000..bb23f19
--- /dev/null
+++ b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/DestinationTaskInputsProperty.java
@@ -0,0 +1,92 @@
+/**
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.tez.dag.library.vertexmanager;
+
+// Each destination task fetches data from numOfSourceTasks consecutive
+// source tasks, the first of which has index firstSourceTaskIndex.
+// For any source task in that range, each destination task fetches
+// numOfPartitions consecutive physical outputs with the first physical output
+// index being firstPartitionId.
+class DestinationTaskInputsProperty {
+  private final int firstPartitionId;
+  private final int numOfPartitions;
+  private final int firstSourceTaskIndex;
+  private final int numOfSourceTasks;
+  public DestinationTaskInputsProperty(int firstPartitionId,
+      int numOfPartitions, int firstSourceTaskIndex, int numOfSourceTasks) {
+    this.firstPartitionId = firstPartitionId;
+    this.numOfPartitions = numOfPartitions;
+    this.firstSourceTaskIndex = firstSourceTaskIndex;
+    this.numOfSourceTasks = numOfSourceTasks;
+  }
+  public int getFirstPartitionId() {
+    return firstPartitionId;
+  }
+  public int getNumOfPartitions() {
+    return numOfPartitions;
+  }
+  public int getFirstSourceTaskIndex() {
+    return firstSourceTaskIndex;
+  }
+  public int getNumOfSourceTasks() {
+    return numOfSourceTasks;
+  }
+  public boolean isSourceTaskInRange(int sourceTaskIndex) {
+    return firstSourceTaskIndex <= sourceTaskIndex &&
+        sourceTaskIndex < firstSourceTaskIndex +
+            numOfSourceTasks;
+  }
+  public boolean isPartitionInRange(int partitionId) {
+    return firstPartitionId <= partitionId &&
+        partitionId < firstPartitionId + numOfPartitions;
+  }
+
+  // The first physical input index for the source task
+  public int getFirstPhysicalInputIndex(int sourceTaskIndex) {
+    return getPhysicalInputIndex(sourceTaskIndex, firstPartitionId);
+  }
+
+  // The physical input index for the physical output index of the source task
+  public int getPhysicalInputIndex(int sourceTaskIndex, int partitionId) {
+    if (isSourceTaskInRange(sourceTaskIndex) &&
+        isPartitionInRange(partitionId)) {
+      return (sourceTaskIndex - firstSourceTaskIndex) * numOfPartitions +
+          (partitionId - firstPartitionId);
+    } else {
+      return -1;
+    }
+  }
+
+  public int getNumOfPhysicalInputs() {
+    return numOfPartitions * numOfSourceTasks;
+  }
+
+  public int getSourceTaskIndex(int physicalInputIndex) {
+    return firstSourceTaskIndex + physicalInputIndex / numOfPartitions;
+  }
+
+  @Override
+  public String toString() {
+    return "firstPartitionId = " + firstPartitionId +
+        " ,numOfPartitions = " + numOfPartitions +
+        " ,firstSourceTaskIndex = " + firstSourceTaskIndex +
+        " ,numOfSourceTasks = " + numOfSourceTasks;
+  }
+}
+

http://git-wip-us.apache.org/repos/asf/tez/blob/2c4ef9fe/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/FairEdgeConfiguration.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/FairEdgeConfiguration.java b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/FairEdgeConfiguration.java
new file mode 100644
index 0000000..846e0a3
--- /dev/null
+++ b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/FairEdgeConfiguration.java
@@ -0,0 +1,111 @@
+/**
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.tez.dag.library.vertexmanager;
+
+import com.google.protobuf.ByteString;
+import com.google.protobuf.InvalidProtocolBufferException;
+
+import org.apache.tez.dag.api.UserPayload;
+import org.apache.tez.dag.library.vertexmanager.FairShuffleUserPayloads.FairShuffleEdgeManagerConfigPayloadProto;
+import org.apache.tez.dag.library.vertexmanager.FairShuffleUserPayloads.FairShuffleEdgeManagerDestinationTaskPropProto;
+import org.apache.tez.dag.library.vertexmanager.FairShuffleUserPayloads.RangeProto;
+
+import java.nio.ByteBuffer;
+import java.util.HashMap;
+import java.util.Map.Entry;
+
+
+/**
+ * Handles edge configuration serialization and de-serialization between
+ * {@link FairShuffleVertexManager} and {@link FairShuffleEdgeManager}.
+ */
+class FairEdgeConfiguration {
+  private final int numBuckets;
+  private final HashMap<Integer, DestinationTaskInputsProperty>
+      destinationInputsProperties;
+
+  public FairEdgeConfiguration(int numBuckets,
+      HashMap<Integer, DestinationTaskInputsProperty> routingTable) {
+    this.destinationInputsProperties = routingTable;
+    this.numBuckets = numBuckets;
+  }
+
+  private FairShuffleEdgeManagerConfigPayloadProto getConfigPayload() {
+    FairShuffleEdgeManagerConfigPayloadProto.Builder builder =
+        FairShuffleEdgeManagerConfigPayloadProto.newBuilder();
+    builder.setNumBuckets(numBuckets);
+    if (destinationInputsProperties != null) {
+      for (Entry<Integer, DestinationTaskInputsProperty> entry :
+          destinationInputsProperties.entrySet()) {
+        FairShuffleEdgeManagerDestinationTaskPropProto.Builder taskBuilder =
+            FairShuffleEdgeManagerDestinationTaskPropProto.newBuilder();
+        taskBuilder.
+            setDestinationTaskIndex(entry.getKey()).
+            setPartitions(newRange(entry.getValue().getFirstPartitionId(),
+            entry.getValue().getNumOfPartitions())).
+            setSourceTasks(newRange(entry.getValue().
+            getFirstSourceTaskIndex(), entry.getValue().getNumOfSourceTasks()));
+        builder.addDestinationTaskProps(taskBuilder.build());
+      }
+    }
+    return builder.build();
+  }
+
+  private RangeProto newRange(int firstIndex, int numOfIndexes) {
+    return RangeProto.newBuilder().
+        setFirstIndex(firstIndex).setNumOfIndexes(numOfIndexes).build();
+  }
+
+  static FairEdgeConfiguration fromUserPayload(UserPayload payload)
+      throws InvalidProtocolBufferException {
+    HashMap<Integer, DestinationTaskInputsProperty> routingTable = new HashMap<>();
+    FairShuffleEdgeManagerConfigPayloadProto proto =
+        FairShuffleEdgeManagerConfigPayloadProto.parseFrom(
+            ByteString.copyFrom(payload.getPayload()));
+    int numBuckets = proto.getNumBuckets();
+    if (proto.getDestinationTaskPropsList() != null) {
+      for (int i = 0; i < proto.getDestinationTaskPropsList().size(); i++) {
+        FairShuffleEdgeManagerDestinationTaskPropProto propProto =
+            proto.getDestinationTaskPropsList().get(i);
+        routingTable.put(
+            propProto.getDestinationTaskIndex(),
+            new DestinationTaskInputsProperty(
+                propProto.getPartitions().getFirstIndex(),
+                propProto.getPartitions().getNumOfIndexes(),
+                propProto.getSourceTasks().getFirstIndex(),
+                propProto.getSourceTasks().getNumOfIndexes()));
+      }
+    }
+    return new FairEdgeConfiguration(numBuckets, routingTable);
+  }
+
+  public HashMap<Integer, DestinationTaskInputsProperty> getRoutingTable() {
+    return destinationInputsProperties;
+  }
+
+  // The number of partitions used by source vertex.
+  int getNumBuckets() {
+    return numBuckets;
+  }
+
+  UserPayload getBytePayload() {
+    return UserPayload.create(ByteBuffer.wrap(
+        getConfigPayload().toByteArray()));
+  }
+}

http://git-wip-us.apache.org/repos/asf/tez/blob/2c4ef9fe/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/FairShuffleEdgeManager.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/FairShuffleEdgeManager.java b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/FairShuffleEdgeManager.java
new file mode 100644
index 0000000..ff1c032
--- /dev/null
+++ b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/FairShuffleEdgeManager.java
@@ -0,0 +1,154 @@
+/**
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.tez.dag.library.vertexmanager;
+
+import com.google.protobuf.InvalidProtocolBufferException;
+
+import org.apache.tez.dag.api.EdgeManagerPluginContext;
+import org.apache.tez.dag.api.EdgeManagerPluginOnDemand;
+import org.apache.tez.dag.api.UserPayload;
+
+import javax.annotation.Nullable;
+
+import java.util.HashMap;
+
+/**
+ * Edge manager for fair routing. Each destination task has its
+ * DestinationTaskInputsProperty used to decide how to do event routing
+ * between source and destination.
+ */
+public class FairShuffleEdgeManager extends EdgeManagerPluginOnDemand {
+
+  private FairEdgeConfiguration conf = null;
+  // The key in the mapping is the destination task index.
+  // The value in the mapping is DestinationTaskInputsProperty of the
+  // destination task.
+  private HashMap<Integer, DestinationTaskInputsProperty> mapping;
+
+  // Used by the framework at runtime; initialize() performs the real setup.
+  public FairShuffleEdgeManager(EdgeManagerPluginContext context) {
+    super(context);
+  }
+
+  @Override
+  public int getNumDestinationTaskPhysicalInputs(int destTaskIndex) {
+    return mapping.get(destTaskIndex).getNumOfPhysicalInputs();
+  }
+
+  @Override
+  public int getNumSourceTaskPhysicalOutputs(int sourceTaskIndex) {
+    return conf.getNumBuckets();
+  }
+
+  @Override
+  public int getNumDestinationConsumerTasks(int sourceTaskIndex) {
+    int numTasks = 0;
+    for(DestinationTaskInputsProperty entry: mapping.values()) {
+      if (entry.isSourceTaskInRange(sourceTaskIndex)) {
+        numTasks++;
+      }
+    }
+    return numTasks;
+  }
+
+  // called at runtime to initialize the custom edge.
+  @Override
+  public void initialize() {
+    UserPayload userPayload = getContext().getUserPayload();
+    if (userPayload == null || userPayload.getPayload() == null ||
+        userPayload.getPayload().limit() == 0) {
+      throw new RuntimeException("Could not initialize FairShuffleEdgeManager"
+          + " from provided user payload");
+    }
+    try {
+      conf = FairEdgeConfiguration.fromUserPayload(userPayload);
+      mapping = conf.getRoutingTable();
+    } catch (InvalidProtocolBufferException e) {
+      throw new RuntimeException("Could not initialize FairShuffleEdgeManager"
+          + " from provided user payload", e);
+    }
+  }
+
+  @Override
+  public int routeInputErrorEventToSource(int destinationTaskIndex,
+      int destinationFailedInputIndex) {
+    return mapping.get(destinationTaskIndex).getSourceTaskIndex(
+        destinationFailedInputIndex);
+  }
+
+  @Override
+  public void prepareForRouting() throws Exception {
+  }
+
+  @Override
+  public EventRouteMetadata routeDataMovementEventToDestination(
+      int sourceTaskIndex, int sourceOutputIndex, int destTaskIndex)
+      throws Exception {
+    DestinationTaskInputsProperty property = mapping.get(destTaskIndex);
+    int targetIndex = property.getPhysicalInputIndex(sourceTaskIndex,
+        sourceOutputIndex);
+    if (targetIndex != -1) {
+      return EventRouteMetadata.create(1, new int[]{targetIndex});
+    } else {
+      return null;
+    }
+  }
+
+  // Create an array of "count" consecutive integers with starting
+  // value equal to "startValue".
+  private int[] getRange(int startValue, int count) {
+    int[] values = new int[count];
+    for (int i = 0; i < count; i++) {
+      values[i] = startValue + i;
+    }
+    return values;
+  }
+
+  @Override
+  public @Nullable EventRouteMetadata
+      routeCompositeDataMovementEventToDestination(int sourceTaskIndex,
+      int destinationTaskIndex) {
+    DestinationTaskInputsProperty property = mapping.get(destinationTaskIndex);
+    int firstPhysicalInputIndex =
+        property.getFirstPhysicalInputIndex(sourceTaskIndex);
+    if (firstPhysicalInputIndex >= 0) {
+      return EventRouteMetadata.create(property.getNumOfPartitions(),
+          getRange(firstPhysicalInputIndex, property.getNumOfPartitions()),
+          getRange(property.getFirstPartitionId(),
+          property.getNumOfPartitions()));
+    } else {
+      return null;
+    }
+  }
+
+  @Override
+  public EventRouteMetadata routeInputSourceTaskFailedEventToDestination(
+      int sourceTaskIndex, int destinationTaskIndex) throws Exception {
+    DestinationTaskInputsProperty property = mapping.get(destinationTaskIndex);
+    int firstPhysicalInputIndex =
+        property.getFirstPhysicalInputIndex(sourceTaskIndex);
+    if (firstPhysicalInputIndex >= 0) {
+      return EventRouteMetadata.create(property.getNumOfPartitions(),
+          getRange(firstPhysicalInputIndex, property.getNumOfPartitions()));
+    } else {
+      return null;
+    }
+  }
+}
+

http://git-wip-us.apache.org/repos/asf/tez/blob/2c4ef9fe/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/FairShuffleVertexManager.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/FairShuffleVertexManager.java b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/FairShuffleVertexManager.java
new file mode 100644
index 0000000..a8b336c
--- /dev/null
+++ b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/FairShuffleVertexManager.java
@@ -0,0 +1,631 @@
+/**
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.tez.dag.library.vertexmanager;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import com.google.common.collect.UnmodifiableIterator;
+
+import com.google.common.primitives.Ints;
+import org.apache.tez.common.TezUtils;
+import org.apache.tez.dag.api.EdgeManagerPluginDescriptor;
+import org.apache.tez.dag.api.EdgeProperty;
+import org.apache.tez.dag.api.TezUncheckedException;
+import org.apache.tez.dag.api.VertexManagerPluginContext;
+import org.apache.tez.dag.api.VertexManagerPluginDescriptor;
+import org.apache.tez.dag.api.VertexManagerPluginContext.ScheduleTaskRequest;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.apache.hadoop.classification.InterfaceAudience.Public;
+import org.apache.hadoop.classification.InterfaceStability.Evolving;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.tez.runtime.api.TaskAttemptIdentifier;
+
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.math.BigInteger;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.NoSuchElementException;
+
+
+/**
+ * Fair routing based on partition size distribution to achieve optimal
+ * input size for any destination task and thus reduce data skew.
+ * By default the feature is turned off and it supports the regular shuffle like
+ * ShuffleVertexManager.
+ * When the feature is turned on, there are two routing types as defined in
+ * {@link FairRoutingType}. One is {@link FairRoutingType#REDUCE_PARALLELISM}
+ * which is similar to ShuffleVertexManager's auto reduce functionality.
+ * Another one is {@link FairRoutingType#FAIR_PARALLELISM} where each
+ * destination task can process a range of consecutive partitions from a range
+ * of consecutive source tasks.
+ */
+@Public
+@Evolving
+public class FairShuffleVertexManager extends ShuffleVertexManagerBase {
+
+  private static final Logger LOG =
+      LoggerFactory.getLogger(FairShuffleVertexManager.class);
+
+  /**
+   * The desired size of input per task. Parallelism will be changed to meet
+   * this criterion.
+   */
+  public static final String TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE =
+      "tez.fair-shuffle-vertex-manager.desired-task-input-size";
+  public static final long
+      TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE_DEFAULT = 100 * MB;
+
+  /**
+   * Enables automatic parallelism determination for the vertex. Based on input data
+   * statistics the parallelism is adjusted to a desired level.
+   */
+  public static final String TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL =
+      "tez.fair-shuffle-vertex-manager.enable.auto-parallel";
+  public static final String
+      TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL_DEFAULT =
+          FairRoutingType.NONE.getType();
+
+  /**
+   * In case of a ScatterGather connection, the fraction of source tasks which
+   * should complete before tasks for the current vertex are scheduled
+   */
+  public static final String TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION =
+      "tez.fair-shuffle-vertex-manager.min-src-fraction";
+  public static final float TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION_DEFAULT = 0.25f;
+
+  /**
+   * In case of a ScatterGather connection, once this fraction of source tasks
+   * have completed, all tasks on the current vertex can be scheduled. Number of
+   * tasks ready for scheduling on the current vertex scales linearly between
+   * min-fraction and max-fraction. Defaults to the greater of the default value
+   * or tez.fair-shuffle-vertex-manager.min-src-fraction.
+   */
+  public static final String TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION =
+      "tez.fair-shuffle-vertex-manager.max-src-fraction";
+  public static final float TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION_DEFAULT = 0.75f;
+
+  /**
+   * The fair routing policy. It selects whether, and how, partitions and
+   * source tasks are regrouped across destination tasks.
+   */
+  public enum FairRoutingType {
+    /**
+     * Don't do any fair routing.
+     */
+    NONE("none"),
+
+    /**
+     * TEZ-2962 Based on input data statistics the parallelism is decreased
+     * to a desired level by having one destination task process multiple
+     * consecutive partitions.
+     */
+    REDUCE_PARALLELISM("reduce_parallelism"),
+
+    /**
+     * Based on input data statistics the parallelism is adjusted
+     * to a desired level by having one destination task process multiple
+     * small partitions and multiple destination tasks process one
+     * large partition. Only works when there is one bipartite edge.
+     */
+    FAIR_PARALLELISM("fair_parallelism");
+
+    private final String type;
+
+    private FairRoutingType(String type) {
+      this.type = type;
+    }
+
+    public final String getType() {
+      return type;
+    }
+
+    public boolean reduceParallelismEnabled() {
+      return equals(FairRoutingType.REDUCE_PARALLELISM);
+    }
+
+    public boolean fairParallelismEnabled() {
+      return equals(FairRoutingType.FAIR_PARALLELISM);
+    }
+
+    public boolean enabled() {
+      return !equals(FairRoutingType.NONE);
+    }
+
+    public static FairRoutingType fromString(String type) {
+      if (type != null) {
+        for (FairRoutingType b : FairRoutingType.values()) {
+          if (type.equalsIgnoreCase(b.type)) {
+            return b;
+          }
+        }
+      }
+      throw new IllegalArgumentException("Invalid type " + type);
+    }
+  }
+
+  static class FairSourceVertexInfo extends SourceVertexInfo {
+    // mapping from destination task id to DestinationTaskInputsProperty
+    private final HashMap<Integer, DestinationTaskInputsProperty>
+        destinationInputsProperties = new HashMap<>();
+
+    FairSourceVertexInfo(final EdgeProperty edgeProperty,
+        int totalTasksToSchedule) {
+      super(edgeProperty, totalTasksToSchedule);
+    }
+    public HashMap<Integer, DestinationTaskInputsProperty>
+        getDestinationInputsProperties() {
+      return destinationInputsProperties;
+    }
+  }
+
+  @Override
+  SourceVertexInfo createSourceVertexInfo(EdgeProperty edgeProperty,
+      int numTasks) {
+    return new FairSourceVertexInfo(edgeProperty, numTasks);
+  }
+
+
+  FairShuffleVertexManagerConfig mgrConfig;
+
+  public FairShuffleVertexManager(VertexManagerPluginContext context) {
+    super(context);
+  }
+
+  @Override
+  protected void onVertexStartedCheck() {
+    super.onVertexStartedCheck();
+    if (bipartiteSources > 1 &&
+        (mgrConfig.getFairRoutingType().fairParallelismEnabled())) {
+      // TODO TEZ-3500
+      throw new TezUncheckedException(
+          "Having more than one destination task process same partition(s) " +
+              "only works with one bipartite source.");
+    }
+  }
+
+  static long ceil(long a, long b) {
+    return (a + (b - 1)) / b;
+  }
+
+  public long[] estimatePartitionSize() {
+    boolean partitionStatsReported = false;
+    int numOfPartitions = pendingTasks.size();
+    long[] estimatedPartitionOutputSize = new long[numOfPartitions];
+    for (int i = 0; i < numOfPartitions; i++) {
+      if (getCurrentlyKnownStatsAtIndex(i) > 0) {
+        partitionStatsReported = true;
+        break;
+      }
+    }
+
+    if (!partitionStatsReported) {
+      // partition stats reporting isn't enabled at the source. Use
+      // expected source output size and assume all partitions are evenly
+      // distributed.
+      if (numOfPartitions > 0) {
+        long estimatedPerPartitionSize =
+                getExpectedTotalBipartiteSourceTasksOutputSize().divide(
+                        BigInteger.valueOf(numOfPartitions)).longValue();
+        for (int i = 0; i < numOfPartitions; i++) {
+          estimatedPartitionOutputSize[i] = estimatedPerPartitionSize;
+        }
+      }
+    } else {
+      for (int i = 0; i < numOfPartitions; i++) {
+        estimatedPartitionOutputSize[i] =
+            MB * getExpectedStatsAtIndex(i);
+      }
+    }
+    return estimatedPartitionOutputSize;
+  }
+
+  /*
+   * The class calculates how partitions and source tasks should be
+   * grouped together. It allows a destination task to fetch a consecutive
+   * range of partitions from a consecutive range of source tasks to achieve
+   * optimal physical input size specified by desiredTaskInputDataSize.
+   * First it estimates the size of each partition at job completion based
+   * on the partition and output size of the completed tasks. The estimation
+   * is stored in estimatedPartitionOutputSize.
+   * Then it walks the partitions starting from beginning.
+   * If a partition is not greater than desiredTaskInputDataSize, it keeps
+   * accumulating the next partition until it is about to exceed
+   * desiredTaskInputDataSize. Then it will create a new destination task to
+   * fetch these small partitions in the range of
+   * {firstPartitionId, numOfPartitions} from all source tasks.
+   * If a partition is larger than desiredTaskInputDataSize:
+   * For the FairRoutingType.REDUCE_PARALLELISM policy, it creates a new
+   * destination task to fetch this large partition from all source tasks.
+   * For the FairRoutingType.FAIR_PARALLELISM policy, it will create multiple
+   * destination tasks, each of which will fetch the large partition from a
+   * range of source tasks.
+   */
+  private class PartitionsGroupingCalculator
+      implements Iterable<DestinationTaskInputsProperty> {
+
+    private final FairSourceVertexInfo sourceVertexInfo;
+
+    // Estimated aggregated partition output size when the job is done.
+    private long[] estimatedPartitionOutputSize;
+
+    // Intermediate states used to group partitions.
+
+    // Total output size of partitions in current group.
+    private long sizeOfPartitions = 0;
+    // Total number of partitions in the current group.
+    private int numOfPartitions = 0;
+    // The first partition id in the current group.
+    private int firstPartitionId = 0;
+    // The # of source tasks a destination task consumes.
+    // When FAIR_PARALLELISM is enabled, there will be multiple destination
+    // tasks processing the same partition and each destination task will
+    // process a range of source tasks of that partition. For a given
+    // partition, the number of source tasks assigned to different destination
+    // tasks should differ by one at most and numOfBaseSourceTasks is the
+    // smaller value. numOfBaseDestinationTasks is the number of destination tasks that
+    // process numOfBaseSourceTasks source tasks.
+    // e.g. if 8 source tasks are assigned to 3 destination tasks, the numbers of
+    // source tasks assigned to these 3 destination tasks are {2, 3, 3}.
+    // numOfBaseDestinationTasks == 1, numOfBaseSourceTasks == 2.
+    private int numOfBaseSourceTasks = 0;
+    private int numOfBaseDestinationTasks = 0;
+    public PartitionsGroupingCalculator(long[] estimatedPartitionOutputSize,
+        FairSourceVertexInfo sourceVertexInfo) {
+      this.estimatedPartitionOutputSize = estimatedPartitionOutputSize;
+      this.sourceVertexInfo = sourceVertexInfo;
+    }
+
+    // Start the processing of the next group of partitions
+    private void startNextPartitionsGroup() {
+      this.firstPartitionId += this.numOfPartitions;
+      this.sizeOfPartitions = 0;
+      this.numOfPartitions = 0;
+      this.numOfBaseSourceTasks = 0;
+      this.numOfBaseDestinationTasks = 0;
+    }
+
+    private int getNextPartitionId() {
+      return this.firstPartitionId + this.numOfPartitions;
+    }
+
+    private void addNextPartition() {
+      if (hasPartitionsLeft()) {
+        this.sizeOfPartitions +=
+            estimatedPartitionOutputSize[getNextPartitionId()];
+        this.numOfPartitions++;
+      }
+    }
+
+    private boolean hasPartitionsLeft() {
+      return getNextPartitionId() < this.estimatedPartitionOutputSize.length;
+    }
+
+    private long getCurrentAndNextPartitionSize() {
+      return hasPartitionsLeft() ? this.sizeOfPartitions +
+          estimatedPartitionOutputSize[getNextPartitionId()] :
+          this.sizeOfPartitions;
+    }
+
+    // For the current source output partition(s), decide how
+    // source tasks should be grouped.
+    private boolean computeSourceTasksGrouping() {
+      boolean finalizeCurrentPartitions = true;
+      int groupCount = Ints.checkedCast(ceil(getCurrentAndNextPartitionSize(),
+          config.getDesiredTaskInputDataSize()));
+      if (groupCount <= 1) {
+        // There is not enough data so far to reach desiredTaskInputDataSize.
+        addNextPartition();
+        if (!hasPartitionsLeft()) {
+          // We have reached the last partition.
+          // Consume from all source tasks.
+          this.numOfBaseDestinationTasks = 1;
+          this.numOfBaseSourceTasks = this.sourceVertexInfo.numTasks;
+        } else {
+          finalizeCurrentPartitions = false;
+        }
+      } else if (numOfPartitions == 0) {
+        // The first partition in the current group exceeds
+        // desiredTaskInputDataSize.
+        addNextPartition();
+        if (mgrConfig.getFairRoutingType().reduceParallelismEnabled()) {
+          // Consume from all source tasks
+          this.numOfBaseDestinationTasks = 1;
+          this.numOfBaseSourceTasks = this.sourceVertexInfo.numTasks;
+        } else {
+          // When groupCount > sourceVertexInfo.numTasks, it means
+          // sizeOfPartitions is too big so that even if
+          // we just have one destination task fetch from one source task the
+          // input size still exceeds desiredTaskInputDataSize.
+          if ((this.sourceVertexInfo.numTasks >= groupCount)) {
+            this.numOfBaseDestinationTasks = groupCount -
+                this.sourceVertexInfo.numTasks % groupCount;
+            this.numOfBaseSourceTasks =
+                this.sourceVertexInfo.numTasks / groupCount;
+          } else {
+            this.numOfBaseDestinationTasks = this.sourceVertexInfo.numTasks;
+            this.numOfBaseSourceTasks = 1;
+          }
+        }
+      } else {
+        // There are existing partitions in the current group. Adding the next
+        // partition causes the total size to exceed desiredTaskInputDataSize.
+        // Let us process the existing partitions in the current group. The
+        // next partition will be processed in the next group.
+        this.numOfBaseDestinationTasks = 1;
+        this.numOfBaseSourceTasks = this.sourceVertexInfo.numTasks;
+      }
+      return finalizeCurrentPartitions;
+    }
+
+    @Override
+    public Iterator<DestinationTaskInputsProperty> iterator() {
+      return new UnmodifiableIterator<DestinationTaskInputsProperty>() {
+        private int j = 0;
+        private boolean visitedAtLeastOnce = false;
+        private int groupIndex = 0;
+
+        // Get number of source tasks in the current group.
+        private int getNumOfSourceTasks() {
+          return groupIndex++ < numOfBaseDestinationTasks ?
+              numOfBaseSourceTasks : numOfBaseSourceTasks + 1;
+        }
+
+        @Override
+        public boolean hasNext() {
+          return j < sourceVertexInfo.numTasks || !visitedAtLeastOnce;
+        }
+
+        @Override
+        public DestinationTaskInputsProperty next() {
+          if (hasNext()) {
+            visitedAtLeastOnce = true;
+            int start = j;
+            int numOfSourceTasks = getNumOfSourceTasks();
+            j += numOfSourceTasks;
+            return new DestinationTaskInputsProperty(firstPartitionId,
+                numOfPartitions, start, numOfSourceTasks);
+          }
+          throw new NoSuchElementException();
+        }
+      };
+    }
+
+    public void compute() {
+      int destinationIndex = 0;
+      while (hasPartitionsLeft()) {
+        if (!computeSourceTasksGrouping()) {
+          continue;
+        }
+        Iterator<DestinationTaskInputsProperty> it = iterator();
+        while(it.hasNext()) {
+          sourceVertexInfo.getDestinationInputsProperties().put(
+              destinationIndex,it.next());
+          destinationIndex++;
+        }
+        startNextPartitionsGroup();
+      }
+    }
+  }
+
+  public ReconfigVertexParams computeRouting() {
+    int currentParallelism = pendingTasks.size();
+    int finalTaskParallelism = 0;
+    long[] estimatedPartitionOutputSize = estimatePartitionSize();
+    for (Map.Entry<String, SourceVertexInfo> vInfo : getBipartiteInfo()) {
+      FairSourceVertexInfo info = (FairSourceVertexInfo)vInfo.getValue();
+      computeParallelism(estimatedPartitionOutputSize, info);
+      if (finalTaskParallelism != 0) {
+        Preconditions.checkState(
+            finalTaskParallelism == info.getDestinationInputsProperties().size(),
+                "the parallelism shall be the same for source vertices");
+      }
+      finalTaskParallelism = info.getDestinationInputsProperties().size();
+
+      FairEdgeConfiguration fairEdgeConfig = new FairEdgeConfiguration(
+          currentParallelism, info.getDestinationInputsProperties());
+      EdgeManagerPluginDescriptor descriptor =
+          EdgeManagerPluginDescriptor.create(
+              FairShuffleEdgeManager.class.getName());
+      descriptor.setUserPayload(fairEdgeConfig.getBytePayload());
+      vInfo.getValue().newDescriptor = descriptor;
+    }
+    ReconfigVertexParams params = new ReconfigVertexParams(
+        finalTaskParallelism, null);
+
+    return params;
+  }
+
+  @Override
+  void postReconfigVertex() {
+  }
+
+  @Override
+  void processPendingTasks() {
+  }
+
+  private void computeParallelism(long[] estimatedPartitionOutputSize,
+      FairSourceVertexInfo sourceVertexInfo) {
+    PartitionsGroupingCalculator calculator = new PartitionsGroupingCalculator(
+        estimatedPartitionOutputSize, sourceVertexInfo);
+    calculator.compute();
+  }
+
+  @Override
+  List<ScheduleTaskRequest> getTasksToSchedule(
+      TaskAttemptIdentifier completedSourceAttempt) {
+    float minSourceVertexCompletedTaskFraction =
+        getMinSourceVertexCompletedTaskFraction();
+    int numTasksToSchedule = getNumOfTasksToScheduleAndLog(
+        minSourceVertexCompletedTaskFraction);
+    if (numTasksToSchedule > 0) {
+      boolean scheduleAll =
+          (numTasksToSchedule == pendingTasks.size());
+      List<ScheduleTaskRequest> tasksToSchedule =
+          Lists.newArrayListWithCapacity(numTasksToSchedule);
+
+      Iterator<PendingTaskInfo> it = pendingTasks.iterator();
+      FairSourceVertexInfo srcInfo = null;
+      int srcTaskId = 0;
+      if (completedSourceAttempt != null) {
+        srcTaskId = completedSourceAttempt.getTaskIdentifier().getIdentifier();
+        String srcVertexName = completedSourceAttempt.getTaskIdentifier().getVertexIdentifier().getName();
+        srcInfo = (FairSourceVertexInfo)getSourceVertexInfo(srcVertexName);
+      }
+      while (it.hasNext() && numTasksToSchedule > 0) {
+        Integer taskIndex = it.next().getIndex();
+        // filter out those destination tasks that don't depend on
+        // this completed source task.
+        // destinationInputsProperties's size could be 0 if routing computation
+        // is skipped.
+        if (!scheduleAll && config.isAutoParallelismEnabled()
+            && srcInfo != null && srcInfo.getDestinationInputsProperties().size() > 0) {
+          DestinationTaskInputsProperty property =
+              srcInfo.getDestinationInputsProperties().get(taskIndex);
+          if (!property.isSourceTaskInRange(srcTaskId)) {
+            LOG.debug("completedSourceTaskIndex {} and taskIndex {} don't " +
+                "connect.", srcTaskId, taskIndex);
+            continue;
+          }
+        }
+        tasksToSchedule.add(ScheduleTaskRequest.create(taskIndex, null));
+        it.remove();
+        numTasksToSchedule--;
+      }
+      return tasksToSchedule;
+    }
+    return null;
+  }
+
+  static class FairShuffleVertexManagerConfig extends ShuffleVertexManagerBaseConfig {
+    final FairRoutingType fairRoutingType;
+    public FairShuffleVertexManagerConfig(final boolean enableAutoParallelism,
+        final long desiredTaskInputDataSize, final float slowStartMinFraction,
+        final float slowStartMaxFraction, final FairRoutingType fairRoutingType) {
+      super(enableAutoParallelism, desiredTaskInputDataSize,
+          slowStartMinFraction, slowStartMaxFraction);
+      this.fairRoutingType = fairRoutingType;
+      LOG.info("fairRoutingType {}", this.fairRoutingType);
+    }
+    FairRoutingType getFairRoutingType() {
+      return fairRoutingType;
+    }
+  }
+
+  @Override
+  ShuffleVertexManagerBaseConfig initConfiguration() {
+    float slowStartMinFraction = conf.getFloat(
+        TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION,
+        TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION_DEFAULT);
+    FairRoutingType fairRoutingType = FairRoutingType.fromString(
+        conf.get(TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL,
+            TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL_DEFAULT));
+
+    mgrConfig = new FairShuffleVertexManagerConfig(
+        fairRoutingType.enabled(),
+        conf.getLong(
+            TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE,
+            TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE_DEFAULT),
+        slowStartMinFraction,
+        conf.getFloat(
+            TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION,
+            Math.max(slowStartMinFraction,
+            TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION_DEFAULT)),
+        fairRoutingType);
+    return mgrConfig;
+  }
+
+  /**
+   * Create a {@link VertexManagerPluginDescriptor} builder that can be used to
+   * configure the plugin.
+   *
+   * @param conf
+   *          {@link Configuration} May be modified in place. May be null if the
+   *          configuration parameters are to be set only via code. If
+   *          configuration values may be changed at runtime via a config file
+   *          then pass in a {@link Configuration} that is initialized from a
+   *          config file. The parameters that are not overridden in code will
+   *          be derived from the Configuration object.
+   * @return {@link FairShuffleVertexManagerConfigBuilder}
+   */
+  public static FairShuffleVertexManagerConfigBuilder
+      createConfigBuilder(@Nullable Configuration conf) {
+    return new FairShuffleVertexManagerConfigBuilder(conf);
+  }
+
+  /**
+   * Helper class to configure FairShuffleVertexManager
+   */
+  public static final class FairShuffleVertexManagerConfigBuilder {
+    private final Configuration conf;
+
+    private FairShuffleVertexManagerConfigBuilder(@Nullable Configuration conf) {
+      if (conf == null) {
+        this.conf = new Configuration(false);
+      } else {
+        this.conf = conf;
+      }
+    }
+
+    public FairShuffleVertexManagerConfigBuilder setAutoParallelism(
+        FairRoutingType fairRoutingType) {
+      conf.set(TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL,
+          fairRoutingType.toString());
+      return this;
+    }
+
+    public FairShuffleVertexManagerConfigBuilder
+        setSlowStartMinSrcCompletionFraction(float minFraction) {
+      conf.setFloat(TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION,
+          minFraction);
+      return this;
+    }
+
+    public FairShuffleVertexManagerConfigBuilder
+        setSlowStartMaxSrcCompletionFraction(float maxFraction) {
+      conf.setFloat(TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION,
+          maxFraction);
+      return this;
+    }
+
+    public FairShuffleVertexManagerConfigBuilder setDesiredTaskInputSize(
+        long desiredTaskInputSize) {
+      conf.setLong(TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE,
+          desiredTaskInputSize);
+      return this;
+    }
+
+    public VertexManagerPluginDescriptor build() {
+      VertexManagerPluginDescriptor desc =
+          VertexManagerPluginDescriptor.create(
+              FairShuffleVertexManager.class.getName());
+
+      try {
+        return desc.setUserPayload(TezUtils.createUserPayloadFromConf(
+            this.conf));
+      } catch (IOException e) {
+        throw new TezUncheckedException(e);
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/tez/blob/2c4ef9fe/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManager.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManager.java b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManager.java
index 9937bd1..55a6ced 100644
--- a/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManager.java
+++ b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManager.java
@@ -75,7 +75,7 @@ public class ShuffleVertexManager extends ShuffleVertexManagerBase {
 
   /**
    * Enables automatic parallelism determination for the vertex. Based on input data
-   * statisitics the parallelism is decreased to a desired level.
+   * statistics the parallelism is decreased to a desired level.
    */
   public static final String TEZ_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL =
       "tez.shuffle-vertex-manager.enable.auto-parallel";
@@ -266,7 +266,6 @@ public class ShuffleVertexManager extends ShuffleVertexManagerBase {
           + sourceIndex % partitionRange;
       return EventRouteMetadata.create(1, new int[]{targetIndex});
     }
-    
 
     
     @Override
@@ -447,19 +446,8 @@ public class ShuffleVertexManager extends ShuffleVertexManagerBase {
     // Change this to use per partition stats for more accuracy TEZ-2962.
     // Instead of aggregating overall size and then dividing equally - coalesce partitions until
     // desired per partition size is achieved.
-    BigInteger expectedTotalSourceTasksOutputSize = BigInteger.ZERO;
-    for (Map.Entry<String, SourceVertexInfo> vInfo : getBipartiteInfo()) {
-      SourceVertexInfo srcInfo = vInfo.getValue();
-      if (srcInfo.numTasks > 0 && srcInfo.numVMEventsReceived > 0) {
-        // this assumes that 1 vmEvent is received per completed task - TEZ-2961
-        // Estimate total size by projecting based on the current average size per event
-        BigInteger srcOutputSize = BigInteger.valueOf(srcInfo.outputSize);
-        BigInteger srcNumTasks = BigInteger.valueOf(srcInfo.numTasks);
-        BigInteger srcNumVMEventsReceived = BigInteger.valueOf(srcInfo.numVMEventsReceived);
-        BigInteger expectedSrcOutputSize = srcOutputSize.multiply(srcNumTasks).divide(srcNumVMEventsReceived);
-        expectedTotalSourceTasksOutputSize = expectedTotalSourceTasksOutputSize.add(expectedSrcOutputSize);
-      }
-    }
+    BigInteger expectedTotalSourceTasksOutputSize =
+        getExpectedTotalBipartiteSourceTasksOutputSize();
 
     LOG.info("Expected output: {} based on actual output: {} from {} vertex " +
         "manager events. desiredTaskInputSize: {} max slow start tasks: {} " +
@@ -527,7 +515,13 @@ public class ShuffleVertexManager extends ShuffleVertexManagerBase {
     EdgeManagerPluginDescriptor descriptor =
         EdgeManagerPluginDescriptor.create(CustomShuffleEdgeManager.class.getName());
     descriptor.setUserPayload(edgeManagerConfig.toUserPayload());
-    ReconfigVertexParams params = new ReconfigVertexParams(finalTaskParallelism, null, descriptor);
+
+    Iterable<Map.Entry<String, SourceVertexInfo>> bipartiteItr = getBipartiteInfo();
+    for(Map.Entry<String, SourceVertexInfo> entry : bipartiteItr) {
+      entry.getValue().newDescriptor = descriptor;
+    }
+    ReconfigVertexParams params =
+        new ReconfigVertexParams(finalTaskParallelism, null);
     return params;
   }
 
@@ -623,13 +617,14 @@ public class ShuffleVertexManager extends ShuffleVertexManagerBase {
         Preconditions.checkState(index < targetIndexes.length,
             "index=" + index +", targetIndexes length=" + targetIndexes.length);
         int[] mapping = targetIndexes[index];
-        long totalStats = 0;
+        int partitionStats = 0;
         for (int i : mapping) {
-          totalStats += stats[i];
+          partitionStats += getCurrentlyKnownStatsAtIndex(i);
         }
-        computedPartitionSizes |= taskInfo.setInputStats(totalStats);
+        computedPartitionSizes |= taskInfo.setInputStats(partitionStats);
       } else {
-        computedPartitionSizes |= taskInfo.setInputStats(stats[index]);
+        computedPartitionSizes |= taskInfo.setInputStats(
+            getCurrentlyKnownStatsAtIndex(index));
       }
     }
     return computedPartitionSizes;
@@ -637,8 +632,6 @@ public class ShuffleVertexManager extends ShuffleVertexManagerBase {
 
 
 
-
-
   /**
    * Create a {@link VertexManagerPluginDescriptor} builder that can be used to
    * configure the plugin.

http://git-wip-us.apache.org/repos/asf/tez/blob/2c4ef9fe/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManagerBase.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManagerBase.java b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManagerBase.java
index dc6cd3b..967d0ea 100644
--- a/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManagerBase.java
+++ b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManagerBase.java
@@ -35,12 +35,12 @@ import org.apache.tez.dag.api.EdgeProperty;
 import org.apache.tez.dag.api.EdgeProperty.DataMovementType;
 import org.apache.tez.dag.api.InputDescriptor;
 import org.apache.tez.dag.api.TezUncheckedException;
-import org.apache.tez.dag.api.VertexManagerPlugin;
 import org.apache.tez.dag.api.VertexManagerPluginContext;
+import org.apache.tez.dag.api.VertexLocationHint;
+import org.apache.tez.dag.api.VertexManagerPlugin;
 import org.apache.tez.dag.api.VertexManagerPluginContext.ScheduleTaskRequest;
 import org.apache.tez.dag.api.event.VertexState;
 import org.apache.tez.dag.api.event.VertexStateUpdate;
-import org.apache.tez.dag.api.VertexLocationHint;
 import org.apache.tez.runtime.library.utils.DATA_RANGE_IN_MB;
 import org.roaringbitmap.RoaringBitmap;
 import org.slf4j.Logger;
@@ -58,6 +58,8 @@ import org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads.VertexMan
 
 import java.io.DataInputStream;
 import java.io.IOException;
+import java.math.BigInteger;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.BitSet;
 import java.util.EnumSet;
 import java.util.HashMap;
@@ -65,13 +67,11 @@ import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
-import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.zip.Inflater;
 
 /**
- * Starts scheduling tasks when number of completed source tasks crosses
- * <code>slowStartMinFraction</code> and schedules all tasks
- *  when <code>slowStartMaxFraction</code> is reached
+ * It provides common functions used by ShuffleVertexManager and
+ * FairShuffleVertexManager.
  */
 @Private
 @Evolving
@@ -102,7 +102,6 @@ abstract class ShuffleVertexManagerBase extends VertexManagerPlugin {
   int totalTasksToSchedule = 0;
 
   @VisibleForTesting
-  long[] stats; //approximate amount of data to be fetched
   Configuration conf;
   ShuffleVertexManagerBaseConfig config;
   // requires synchronized access
@@ -132,10 +131,14 @@ abstract class ShuffleVertexManagerBase extends VertexManagerPlugin {
     int numTasks;
     int numVMEventsReceived;
     long outputSize;
+    int[] statsInMB;
+    EdgeManagerPluginDescriptor newDescriptor;
 
-    SourceVertexInfo(final EdgeProperty edgeProperty) {
+    SourceVertexInfo(final EdgeProperty edgeProperty,
+       int totalTasksToSchedule) {
       this.edgeProperty = edgeProperty;
       this.finishedTaskSet = new BitSet();
+      this.statsInMB = new int[totalTasksToSchedule];
     }
 
     int getNumTasks() {
@@ -145,11 +148,20 @@ abstract class ShuffleVertexManagerBase extends VertexManagerPlugin {
     int getNumCompletedTasks() {
       return finishedTaskSet.cardinality();
     }
+    int getExpectedStatsInMBAtIndex(int index) {
+      return (numVMEventsReceived == 0) ?
+          0: statsInMB[index] * numTasks / numVMEventsReceived;
+    }
+  }
+
+  SourceVertexInfo createSourceVertexInfo(EdgeProperty edgeProperty,
+      int numTasks) {
+    return new SourceVertexInfo(edgeProperty, numTasks);
   }
 
   static class PendingTaskInfo {
     final private int index;
-    private long inputStats;
+    private int inputStats;
 
     public PendingTaskInfo(int index) {
       this.index = index;
@@ -161,11 +173,11 @@ abstract class ShuffleVertexManagerBase extends VertexManagerPlugin {
     public int getIndex() {
       return index;
     }
-    public long getInputStats() {
+    public int getInputStats() {
       return inputStats;
     }
     // return true if stat is set.
-    public boolean setInputStats(long inputStats) {
+    public boolean setInputStats(int inputStats) {
       if (inputStats > 0 && this.inputStats != inputStats) {
         this.inputStats = inputStats;
         return true;
@@ -178,14 +190,11 @@ abstract class ShuffleVertexManagerBase extends VertexManagerPlugin {
   static class ReconfigVertexParams {
     final private int finalParallelism;
     final private VertexLocationHint locationHint;
-    final private EdgeManagerPluginDescriptor descriptor;
 
     public ReconfigVertexParams(final int finalParallelism,
-        final VertexLocationHint locationHint,
-        final EdgeManagerPluginDescriptor descriptor) {
+        final VertexLocationHint locationHint) {
       this.finalParallelism = finalParallelism;
       this.locationHint = locationHint;
-      this.descriptor = descriptor;
     }
 
     public int getFinalParallelism() {
@@ -194,9 +203,6 @@ abstract class ShuffleVertexManagerBase extends VertexManagerPlugin {
     public VertexLocationHint getLocationHint() {
       return locationHint;
     }
-    public EdgeManagerPluginDescriptor getDescriptor() {
-      return descriptor;
-    }
   }
 
   public ShuffleVertexManagerBase(VertexManagerPluginContext context) {
@@ -209,7 +215,8 @@ abstract class ShuffleVertexManagerBase extends VertexManagerPlugin {
     // examine edges after vertex started because until then these may not have been defined
     Map<String, EdgeProperty> inputs = getContext().getInputVertexEdgeProperties();
     for(Map.Entry<String, EdgeProperty> entry : inputs.entrySet()) {
-      srcVertexInfo.put(entry.getKey(), new SourceVertexInfo(entry.getValue()));
+      srcVertexInfo.put(entry.getKey(), createSourceVertexInfo(entry.getValue(),
+          getContext().getVertexNumTasks(getContext().getVertexName())));
       // TODO what if derived class has already called this
       // register for status update from all source vertices
       getContext().registerForVertexStateUpdates(entry.getKey(),
@@ -218,9 +225,7 @@ abstract class ShuffleVertexManagerBase extends VertexManagerPlugin {
         bipartiteSources++;
       }
     }
-    if(bipartiteSources == 0) {
-      throw new TezUncheckedException("Atleast 1 bipartite source should exist");
-    }
+    onVertexStartedCheck();
 
     for (VertexStateUpdate stateUpdate : pendingStateUpdates) {
       handleVertexStateUpdate(stateUpdate);
@@ -249,6 +254,11 @@ abstract class ShuffleVertexManagerBase extends VertexManagerPlugin {
     processPendingTasks(null);
   }
 
+  protected void onVertexStartedCheck() {
+    if(bipartiteSources == 0) {
+      throw new TezUncheckedException("At least 1 bipartite source should exist");
+    }
+  }
 
   @Override
   public synchronized void onSourceTaskCompleted(TaskAttemptIdentifier attempt) {
@@ -274,8 +284,10 @@ abstract class ShuffleVertexManagerBase extends VertexManagerPlugin {
   }
 
   @VisibleForTesting
-  void parsePartitionStats(RoaringBitmap partitionStats) {
-    Preconditions.checkState(stats != null, "Stats should be initialized");
+  void parsePartitionStats(SourceVertexInfo srcInfo,
+      RoaringBitmap partitionStats) {
+    Preconditions.checkState(srcInfo.statsInMB != null,
+        "Stats should be initialized");
     Iterator<Integer> it = partitionStats.iterator();
     final DATA_RANGE_IN_MB[] RANGES = DATA_RANGE_IN_MB.values();
     final int RANGE_LEN = RANGES.length;
@@ -285,14 +297,15 @@ abstract class ShuffleVertexManagerBase extends VertexManagerPlugin {
       int rangeIndex = ((pos) % RANGE_LEN);
       //Add to aggregated stats and normalize to DATA_RANGE_IN_MB.
       if (RANGES[rangeIndex].getSizeInMB() > 0) {
-        stats[index] += RANGES[rangeIndex].getSizeInMB();
+        srcInfo.statsInMB[index] += RANGES[rangeIndex].getSizeInMB();
       }
     }
   }
 
-  void parseDetailedPartitionStats(List<Integer> partitionStats) {
+  void parseDetailedPartitionStats(SourceVertexInfo srcInfo,
+      List<Integer> partitionStats) {
     for (int i=0; i<partitionStats.size(); i++) {
-      stats[i] += partitionStats.get(i);
+      srcInfo.statsInMB[i] += partitionStats.get(i);
     }
   }
 
@@ -344,7 +357,7 @@ abstract class ShuffleVertexManagerBase extends VertexManagerPlugin {
           NonSyncByteArrayInputStream bin = new NonSyncByteArrayInputStream(rawData);
           partitionStats.deserialize(new DataInputStream(bin));
 
-          parsePartitionStats(partitionStats);
+          parsePartitionStats(srcInfo, partitionStats);
 
         } catch (IOException e) {
           throw new TezUncheckedException(e);
@@ -352,7 +365,7 @@ abstract class ShuffleVertexManagerBase extends VertexManagerPlugin {
       } else if (proto.hasDetailedPartitionStats()) {
         List<Integer> detailedPartitionStats =
             proto.getDetailedPartitionStats().getSizeInMbList();
-        parseDetailedPartitionStats(detailedPartitionStats);
+        parseDetailedPartitionStats(srcInfo, detailedPartitionStats);
       }
       srcInfo.numVMEventsReceived++;
       srcInfo.outputSize += sourceTaskOutputSize;
@@ -361,11 +374,11 @@ abstract class ShuffleVertexManagerBase extends VertexManagerPlugin {
 
     if (LOG.isDebugEnabled()) {
       LOG.debug("For attempt: {} received info of output size: {}"
-          + " vertex numEventsReceived: {} vertex output size: {}"
-          + " total numEventsReceived: {} total output size: {}",
-          vmEvent.getProducerAttemptIdentifier(), sourceTaskOutputSize,
-          srcInfo.numVMEventsReceived, srcInfo.outputSize,
-          numVertexManagerEventsReceived, completedSourceTasksOutputSize);
+                      + " vertex numEventsReceived: {} vertex output size: {}"
+                      + " total numEventsReceived: {} total output size: {}",
+              vmEvent.getProducerAttemptIdentifier(), sourceTaskOutputSize,
+              srcInfo.numVMEventsReceived, srcInfo.outputSize,
+              numVertexManagerEventsReceived, completedSourceTasksOutputSize);
     }
   }
 
@@ -379,9 +392,6 @@ abstract class ShuffleVertexManagerBase extends VertexManagerPlugin {
       pendingTasks.add(new PendingTaskInfo(i));
     }
     totalTasksToSchedule = pendingTasks.size();
-    if (stats == null) {
-      stats = new long[totalTasksToSchedule]; // TODO lost previous data
-    }
   }
 
   /**
@@ -427,6 +437,41 @@ abstract class ShuffleVertexManagerBase extends VertexManagerPlugin {
     }
   }
 
+  BigInteger getExpectedTotalBipartiteSourceTasksOutputSize() {
+    BigInteger expectedTotalSourceTasksOutputSize = BigInteger.ZERO;
+    for (Map.Entry<String, SourceVertexInfo> vInfo : getBipartiteInfo()) {
+      SourceVertexInfo srcInfo = vInfo.getValue();
+      if (srcInfo.numTasks > 0 && srcInfo.numVMEventsReceived > 0) {
+        // this assumes that 1 vmEvent is received per completed task - TEZ-2961
+        // Estimate total size by projecting based on the current average size per event
+        BigInteger srcOutputSize = BigInteger.valueOf(srcInfo.outputSize);
+        BigInteger srcNumTasks = BigInteger.valueOf(srcInfo.numTasks);
+        BigInteger srcNumVMEventsReceived = BigInteger.valueOf(srcInfo.numVMEventsReceived);
+        BigInteger expectedSrcOutputSize = srcOutputSize.multiply(
+            srcNumTasks).divide(srcNumVMEventsReceived);
+        expectedTotalSourceTasksOutputSize =
+            expectedTotalSourceTasksOutputSize.add(expectedSrcOutputSize);
+      }
+    }
+    return expectedTotalSourceTasksOutputSize;
+  }
+
+  int getCurrentlyKnownStatsAtIndex(int index) {
+    int stats = 0;
+    for(SourceVertexInfo entry : getAllSourceVertexInfo()) {
+      stats += entry.statsInMB[index];
+    }
+    return stats;
+  }
+
+  int getExpectedStatsAtIndex(int index) {
+    int stats = 0;
+    for(SourceVertexInfo entry : getAllSourceVertexInfo()) {
+      stats += entry.getExpectedStatsInMBAtIndex(index);
+    }
+    return stats;
+  }
+
   /**
    * Subclass might return null to indicate there is no new routing.
    */
@@ -447,7 +492,7 @@ abstract class ShuffleVertexManagerBase extends VertexManagerPlugin {
       if (computeRoutingAction.equals(computeRoutingAction.COMPUTE)) {
         ReconfigVertexParams params = computeRouting();
         if (params != null) {
-          reconfigVertex(params.getFinalParallelism(), params.getDescriptor());
+          reconfigVertex(params.getFinalParallelism());
           updatePendingTasks();
           postReconfigVertex();
         }
@@ -489,6 +534,14 @@ abstract class ShuffleVertexManagerBase extends VertexManagerPlugin {
     }
   }
 
+  Iterable<SourceVertexInfo> getAllSourceVertexInfo() {
+    return srcVertexInfo.values();
+  }
+
+  SourceVertexInfo getSourceVertexInfo(String vertexName) { // looks vertexName up in srcVertexInfo
+    return srcVertexInfo.get(vertexName);
+  }
+
   Iterable<Map.Entry<String, SourceVertexInfo>> getBipartiteInfo() {
     return Iterables.filter(srcVertexInfo.entrySet(),
         new Predicate<Map.Entry<String,SourceVertexInfo>>() {
@@ -753,20 +806,18 @@ abstract class ShuffleVertexManagerBase extends VertexManagerPlugin {
     // Not allowing this for now. Nothing to do.
   }
 
-  private void reconfigVertex(final int finalTaskParallelism,
-      final EdgeManagerPluginDescriptor edgeManagerDescriptor) {
+  private void reconfigVertex(final int finalTaskParallelism) {
     Map<String, EdgeProperty> edgeProperties =
         new HashMap<String, EdgeProperty>(bipartiteSources);
     Iterable<Map.Entry<String, SourceVertexInfo>> bipartiteItr = getBipartiteInfo();
     for(Map.Entry<String, SourceVertexInfo> entry : bipartiteItr) {
       String vertex = entry.getKey();
       EdgeProperty oldEdgeProp = entry.getValue().edgeProperty;
-      EdgeProperty newEdgeProp = EdgeProperty.create(edgeManagerDescriptor,
+      EdgeProperty newEdgeProp = EdgeProperty.create(entry.getValue().newDescriptor,
           oldEdgeProp.getDataSourceType(), oldEdgeProp.getSchedulingType(),
           oldEdgeProp.getEdgeSource(), oldEdgeProp.getEdgeDestination());
       edgeProperties.put(vertex, newEdgeProp);
     }
-
     getContext().reconfigureVertex(finalTaskParallelism, null, edgeProperties);
   }
 }

http://git-wip-us.apache.org/repos/asf/tez/blob/2c4ef9fe/tez-runtime-library/src/main/proto/FairShufflePayloads.proto
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/proto/FairShufflePayloads.proto b/tez-runtime-library/src/main/proto/FairShufflePayloads.proto
new file mode 100644
index 0000000..334cbc9
--- /dev/null
+++ b/tez-runtime-library/src/main/proto/FairShufflePayloads.proto
@@ -0,0 +1,37 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+option java_package = "org.apache.tez.dag.library.vertexmanager";
+option java_outer_classname = "FairShuffleUserPayloads";
+option java_generate_equals_and_hash = true;
+
+message RangeProto {
+  optional int32 first_index = 1;
+  optional int32 num_of_indexes = 2;
+}
+
+message FairShuffleEdgeManagerDestinationTaskPropProto {
+  optional int32 destination_task_index = 1;
+  optional RangeProto partitions = 2;
+  optional RangeProto source_tasks = 3;
+}
+
+message FairShuffleEdgeManagerConfigPayloadProto {
+  optional int32 num_buckets = 1;
+  repeated FairShuffleEdgeManagerDestinationTaskPropProto destinationTaskProps = 2;
+}

http://git-wip-us.apache.org/repos/asf/tez/blob/2c4ef9fe/tez-runtime-library/src/test/java/org/apache/tez/dag/library/vertexmanager/TestFairShuffleVertexManager.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/test/java/org/apache/tez/dag/library/vertexmanager/TestFairShuffleVertexManager.java b/tez-runtime-library/src/test/java/org/apache/tez/dag/library/vertexmanager/TestFairShuffleVertexManager.java
new file mode 100644
index 0000000..9c94c14
--- /dev/null
+++ b/tez-runtime-library/src/test/java/org/apache/tez/dag/library/vertexmanager/TestFairShuffleVertexManager.java
@@ -0,0 +1,347 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.dag.library.vertexmanager;
+
+import com.google.common.collect.Lists;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.tez.dag.api.EdgeManagerPlugin;
+import org.apache.tez.dag.api.EdgeManagerPluginOnDemand;
+import org.apache.tez.dag.api.EdgeProperty;
+import org.apache.tez.dag.api.EdgeProperty.SchedulingType;
+import org.apache.tez.dag.api.InputDescriptor;
+import org.apache.tez.dag.api.OutputDescriptor;
+import org.apache.tez.dag.api.TezUncheckedException;
+import org.apache.tez.dag.api.VertexLocationHint;
+import org.apache.tez.dag.api.VertexManagerPluginContext;
+import org.apache.tez.dag.api.event.VertexState;
+import org.apache.tez.dag.api.event.VertexStateUpdate;
+import org.apache.tez.dag.library.vertexmanager.FairShuffleVertexManager.FairRoutingType;
+import org.apache.tez.runtime.api.TaskAttemptIdentifier;
+import org.apache.tez.runtime.api.events.VertexManagerEvent;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.anyInt;
+import static org.mockito.Mockito.anyList;
+import static org.mockito.Mockito.anyMap;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+@SuppressWarnings({ "unchecked", "rawtypes" })
+public class TestFairShuffleVertexManager
+    extends TestShuffleVertexManagerUtils {
+  List<TaskAttemptIdentifier> emptyCompletions = null;
+
+  @Test(timeout = 5000)
+  public void testAutoParallelismConfig() throws Exception { // checks manager.config for two setups
+    FairShuffleVertexManager manager;
+
+    final List<Integer> scheduledTasks = Lists.newLinkedList();
+
+    final VertexManagerPluginContext mockContext = createVertexManagerContext(
+        "Vertex1", 2, "Vertex2", 2, "Vertex3", 2,
+            "Vertex4", 4, scheduledTasks, null);
+
+    manager = createManager(null, mockContext, null, 0.5f); // min left null, max = 0.5f
+    verify(mockContext, times(1)).vertexReconfigurationPlanned(); // Tez notified of reconfig
+    Assert.assertTrue(manager.config.isAutoParallelismEnabled());
+    Assert.assertTrue(manager.config.getDesiredTaskInputDataSize() == 1000l * MB);
+    Assert.assertTrue(manager.config.getMinFraction() == 0.25f);
+    Assert.assertTrue(manager.config.getMaxFraction() == 0.5f);
+
+    manager = createManager(null, mockContext, null, null, null, null); // every setting left null
+    verify(mockContext, times(1)).vertexReconfigurationPlanned(); // still one call in total: no new notification
+
+    Assert.assertTrue(!manager.config.isAutoParallelismEnabled());
+    Assert.assertTrue(manager.config.getDesiredTaskInputDataSize() ==
+        FairShuffleVertexManager.TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE_DEFAULT);
+    Assert.assertTrue(manager.config.getMinFraction() ==
+        FairShuffleVertexManager.TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION_DEFAULT);
+    Assert.assertTrue(manager.config.getMaxFraction() ==
+        FairShuffleVertexManager.TEZ_FAIR_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION_DEFAULT);
+  }
+
+  @Test(timeout = 5000)
+  public void testInvalidSetup() {
+    // Expects onVertexStarted to reject FAIR_PARALLELISM for the context built below.
+    Configuration conf = new Configuration();
+    final List<Integer> scheduledTasks = Lists.newLinkedList();
+
+    final VertexManagerPluginContext mockContext = createVertexManagerContext(
+        "Vertex1", 2, "Vertex2", 2, "Vertex3", 2,
+        "Vertex4", 4, scheduledTasks, null);
+
+    // Fail if there is more than one bipartite source for FAIR_PARALLELISM.
+    try {
+      ShuffleVertexManagerBase manager = createFairShuffleVertexManager(conf, mockContext,
+          FairRoutingType.FAIR_PARALLELISM, 1000 * MB, 0.001f, 0.001f);
+      manager.onVertexStarted(emptyCompletions);
+      // Reaching this line means the invalid setup was accepted.
+      Assert.fail("FAIR_PARALLELISM with multiple bipartite sources should throw");
+    } catch (TezUncheckedException e) {
+      Assert.assertTrue(e.getMessage().contains(
+          "Having more than one destination task process same partition(s) " +
+              "only works with one bipartite source."));
+    }
+  }
+
+  @Test(timeout = 5000)
+  public void testReduceSchedulingWithPartitionStats() throws Exception {
+    final Map<String, EdgeManagerPlugin> newEdgeManagers =
+        new HashMap<String, EdgeManagerPlugin>();
+    testSchedulingWithPartitionStats(FairRoutingType.REDUCE_PARALLELISM,
+        2, 2, newEdgeManagers); // expects 2 scheduled tasks, 2 consumer tasks per source task
+    EdgeManagerPluginOnDemand edgeManager =
+        (EdgeManagerPluginOnDemand)newEdgeManagers.values().iterator().next();
+
+    // The first destination task fetches two partitions from all source tasks.
+    // 6 == 3 source tasks * 2 merged partitions
+    Assert.assertEquals(6, edgeManager.getNumDestinationTaskPhysicalInputs(0));
+    EdgeManagerPluginOnDemand.EventRouteMetadata routeMetadata;
+    for (int sourceTaskIndex = 0; sourceTaskIndex < 3; sourceTaskIndex++) {
+      for (int j = 0; j < 2; j++) { // j == 0: data movement route; j == 1: source-failed route
+        routeMetadata = (j == 0) ?
+            edgeManager.routeCompositeDataMovementEventToDestination(
+                sourceTaskIndex, 0) :
+            edgeManager.routeInputSourceTaskFailedEventToDestination(
+                sourceTaskIndex, 0);
+        Assert.assertEquals(2, routeMetadata.getNumEvents());
+        if (j == 0) { // source indices are asserted only for the data movement route
+          Assert.assertArrayEquals(new int[]{0, 1},
+              routeMetadata.getSourceIndices());
+        }
+        Assert.assertArrayEquals(
+            new int[]{0 + sourceTaskIndex * 2, 1 + sourceTaskIndex * 2},
+            routeMetadata.getTargetIndices());
+      }
+    }
+  }
+
+  @Test(timeout = 5000)
+  public void testFairSchedulingWithPartitionStats() throws Exception {
+    final Map<String, EdgeManagerPlugin> newEdgeManagers =
+        new HashMap<String, EdgeManagerPlugin>();
+    testSchedulingWithPartitionStats(FairRoutingType.FAIR_PARALLELISM,
+        3, 2, newEdgeManagers); // expects 3 scheduled tasks, 2 consumer tasks per source task
+
+    // Get the first edgeManager which is SCATTER_GATHER.
+    EdgeManagerPluginOnDemand edgeManager =
+        (EdgeManagerPluginOnDemand)newEdgeManagers.values().iterator().next();
+
+    // The first destination task fetches two partitions from all source tasks.
+    // 6 == 3 source tasks * 2 merged partitions
+    Assert.assertEquals(6, edgeManager.getNumDestinationTaskPhysicalInputs(0));
+    EdgeManagerPluginOnDemand.EventRouteMetadata routeMetadata;
+    for (int sourceTaskIndex = 0; sourceTaskIndex < 3; sourceTaskIndex++) {
+      for (int j = 0; j < 2; j++) { // j == 0: data movement route; j == 1: source-failed route
+        routeMetadata = (j == 0) ?
+            edgeManager.routeCompositeDataMovementEventToDestination(
+                sourceTaskIndex, 0) :
+            edgeManager.routeInputSourceTaskFailedEventToDestination(
+                sourceTaskIndex, 0);
+        Assert.assertEquals(2, routeMetadata.getNumEvents());
+        if (j == 0) { // source indices are asserted only for the data movement route
+          Assert.assertArrayEquals(new int[]{0, 1},
+              routeMetadata.getSourceIndices());
+        }
+        Assert.assertArrayEquals(
+            new int[]{0 + sourceTaskIndex * 2, 1 + sourceTaskIndex * 2},
+            routeMetadata.getTargetIndices());
+      }
+    }
+
+    // The 2nd destination task fetches one partition from the first source
+    // task.
+    Assert.assertEquals(1, edgeManager.getNumDestinationTaskPhysicalInputs(1));
+    for (int j = 0; j < 2; j++) { // j == 0: data movement route; j == 1: source-failed route
+      routeMetadata = (j == 0) ?
+          edgeManager.routeCompositeDataMovementEventToDestination(
+              0, 1) :
+          edgeManager.routeInputSourceTaskFailedEventToDestination(
+              0, 1);
+      Assert.assertEquals(1, routeMetadata.getNumEvents());
+      if (j == 0) { // source index is asserted only for the data movement route
+        Assert.assertEquals(2, routeMetadata.getSourceIndices()[0]);
+      }
+      Assert.assertEquals(0, routeMetadata.getTargetIndices()[0]);
+    }
+
+    // The 3rd destination task fetches one partition from the 2nd and 3rd
+    // source tasks.
+    Assert.assertEquals(2, edgeManager.getNumDestinationTaskPhysicalInputs(2));
+    for (int sourceTaskIndex = 1; sourceTaskIndex < 3; sourceTaskIndex++) {
+      for (int j = 0; j < 2; j++) { // j == 0: data movement route; j == 1: source-failed route
+        routeMetadata = (j == 0) ?
+            edgeManager.routeCompositeDataMovementEventToDestination(
+                sourceTaskIndex, 2) :
+            edgeManager.routeInputSourceTaskFailedEventToDestination(
+                sourceTaskIndex, 2);
+        Assert.assertEquals(1, routeMetadata.getNumEvents());
+        if (j == 0) { // source index is asserted only for the data movement route
+          Assert.assertEquals(2, routeMetadata.getSourceIndices()[0]);
+        }
+        Assert.assertEquals(sourceTaskIndex - 1,
+            routeMetadata.getTargetIndices()[0]);
+      }
+    }
+  }
+
+  // Builds a DAG with one destination vertex (R2) fed by 3 source vertices of 3 tasks each:
+  // R1 over a SCATTER_GATHER edge, M2 and M3 over BROADCAST edges. Sends partition stats for
+  // R1, then checks the scheduled task count and the reconfigured edge manager's fan-out.
+  private void testSchedulingWithPartitionStats(
+      FairRoutingType fairRoutingType, int expectedScheduledTasks,
+      int expectedNumDestinationConsumerTasks,
+      Map<String, EdgeManagerPlugin> newEdgeManagers)
+      throws Exception {
+    Configuration conf = new Configuration();
+    FairShuffleVertexManager manager;
+
+    HashMap<String, EdgeProperty> mockInputVertices = new HashMap<String, EdgeProperty>();
+    String r1 = "R1";
+    final int numOfTasksInr1 = 3;
+    EdgeProperty eProp1 = EdgeProperty.create(
+        EdgeProperty.DataMovementType.SCATTER_GATHER,
+        EdgeProperty.DataSourceType.PERSISTED,
+        SchedulingType.SEQUENTIAL,
+        OutputDescriptor.create("out"),
+        InputDescriptor.create("in"));
+    String m2 = "M2";
+    final int numOfTasksInM2 = 3;
+    EdgeProperty eProp2 = EdgeProperty.create(
+        EdgeProperty.DataMovementType.BROADCAST,
+        EdgeProperty.DataSourceType.PERSISTED,
+        SchedulingType.SEQUENTIAL,
+        OutputDescriptor.create("out"),
+        InputDescriptor.create("in"));
+    String m3 = "M3";
+    final int numOfTasksInM3 = 3;
+    EdgeProperty eProp3 = EdgeProperty.create(
+        EdgeProperty.DataMovementType.BROADCAST,
+        EdgeProperty.DataSourceType.PERSISTED,
+        SchedulingType.SEQUENTIAL,
+        OutputDescriptor.create("out"),
+        InputDescriptor.create("in"));
+
+    final String mockManagedVertexId = "R2";
+    final int numOfTasksInDestination = 3;
+
+    mockInputVertices.put(r1, eProp1);
+    mockInputVertices.put(m2, eProp2);
+    mockInputVertices.put(m3, eProp3);
+
+    final VertexManagerPluginContext mockContext = mock(VertexManagerPluginContext.class);
+    when(mockContext.getInputVertexEdgeProperties()).thenReturn(mockInputVertices);
+    when(mockContext.getVertexName()).thenReturn(mockManagedVertexId);
+    when(mockContext.getVertexNumTasks(mockManagedVertexId)).thenReturn(numOfTasksInDestination);
+    when(mockContext.getVertexNumTasks(r1)).thenReturn(numOfTasksInr1);
+    when(mockContext.getVertexNumTasks(m2)).thenReturn(numOfTasksInM2);
+    when(mockContext.getVertexNumTasks(m3)).thenReturn(numOfTasksInM3);
+
+    final List<Integer> scheduledTasks = Lists.newLinkedList();
+    doAnswer(new ScheduledTasksAnswer(scheduledTasks)).when(
+        mockContext).scheduleTasks(anyList());
+
+    doAnswer(new reconfigVertexAnswer(mockContext, mockManagedVertexId,
+        newEdgeManagers)).when(mockContext).reconfigureVertex(
+        anyInt(), any(VertexLocationHint.class), anyMap());
+
+    // check initialization
+    manager = createFairShuffleVertexManager(conf, mockContext,
+        fairRoutingType, 1000 * MB, 0.001f, 0.001f);
+    manager.onVertexStarted(emptyCompletions);
+    Assert.assertTrue(manager.bipartiteSources == 1);
+
+    manager.onVertexStateUpdated(new VertexStateUpdate(r1,
+        VertexState.CONFIGURED));
+    manager.onVertexStateUpdated(new VertexStateUpdate(m2,
+        VertexState.CONFIGURED));
+
+    Assert.assertEquals(numOfTasksInDestination,
+        manager.pendingTasks.size()); // no tasks scheduled
+    Assert.assertEquals(numOfTasksInr1,
+        manager.totalNumBipartiteSourceTasks);
+    Assert.assertEquals(0, manager.numBipartiteSourceTasksCompleted);
+
+    //Send an event for r1.
+    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(r1, 0));
+    Assert.assertEquals(numOfTasksInDestination, manager.pendingTasks.size()); // no tasks scheduled
+    Assert.assertEquals(numOfTasksInr1, manager.totalNumBipartiteSourceTasks);
+
+    long[] sizes = new long[]{(50 * MB), (200 * MB), (500 * MB)};
+    VertexManagerEvent vmEvent = getVertexManagerEvent(sizes, 800 * MB,
+        r1, true);
+    manager.onVertexManagerEventReceived(vmEvent); //send VM event
+
+    //stats from another task
+    sizes = new long[]{(60 * MB), (300 * MB), (600 * MB)};
+    vmEvent = getVertexManagerEvent(sizes, 1200 * MB, r1, true);
+    manager.onVertexManagerEventReceived(vmEvent); //send VM event
+
+    //Send an event for m2.
+    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(m2, 0));
+    Assert.assertEquals(numOfTasksInDestination, manager.pendingTasks.size()); // no tasks scheduled
+    Assert.assertEquals(numOfTasksInr1, manager.totalNumBipartiteSourceTasks);
+
+    //Send an event for m3.
+    manager.onVertexStateUpdated(new VertexStateUpdate(m3, VertexState.CONFIGURED));
+    manager.onSourceTaskCompleted(createTaskAttemptIdentifier(m3, 0));
+    Assert.assertEquals(0, manager.pendingTasks.size()); // all tasks scheduled
+    Assert.assertEquals(expectedScheduledTasks, scheduledTasks.size());
+
+    Assert.assertEquals(1, newEdgeManagers.size());
+    EdgeManagerPluginOnDemand edgeManager =
+        (EdgeManagerPluginOnDemand)newEdgeManagers.values().iterator().next();
+    // For each source task, there are 3 outputs,
+    // the same as the original number of partitions.
+    for (int i = 0; i < numOfTasksInr1; i++) {
+      Assert.assertEquals(numOfTasksInDestination,
+          edgeManager.getNumSourceTaskPhysicalOutputs(i));
+    }
+
+    for (int sourceTaskIndex = 0; sourceTaskIndex < numOfTasksInr1;
+        sourceTaskIndex++) {
+      Assert.assertEquals(expectedNumDestinationConsumerTasks,
+          edgeManager.getNumDestinationConsumerTasks(sourceTaskIndex));
+    }
+  }
+
+  private static FairShuffleVertexManager createManager(Configuration conf,
+      VertexManagerPluginContext context, Float min, Float max) {
+    return createManager(conf, context, true, 1000l * MB, min, max); // auto parallelism on, 1000 * MB input size
+  }
+
+  private static FairShuffleVertexManager createManager(Configuration conf,
+      VertexManagerPluginContext context,
+      Boolean enableAutoParallelism, Long desiredTaskInputSize, Float min,
+      Float max) {
+    return (FairShuffleVertexManager)TestShuffleVertexManagerBase.createManager(
+        FairShuffleVertexManager.class, conf, context, enableAutoParallelism,
+            desiredTaskInputSize, min, max); // delegates to the shared factory and casts the result
+  }
+}


Mime
View raw message