tez-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From min...@apache.org
Subject tez git commit: TEZ-3458. Auto grouping for cartesian product edge(unpartitioned case). (Zhiyuan Yang via mingma)
Date Wed, 18 Jan 2017 06:41:55 GMT
Repository: tez
Updated Branches:
  refs/heads/master abb350c0c -> 506c9bc3d


TEZ-3458. Auto grouping for cartesian product edge(unpartitioned case). (Zhiyuan Yang via mingma)


Project: http://git-wip-us.apache.org/repos/asf/tez/repo
Commit: http://git-wip-us.apache.org/repos/asf/tez/commit/506c9bc3
Tree: http://git-wip-us.apache.org/repos/asf/tez/tree/506c9bc3
Diff: http://git-wip-us.apache.org/repos/asf/tez/diff/506c9bc3

Branch: refs/heads/master
Commit: 506c9bc3df0c94d94f2ae024ae605faa74e5ba41
Parents: abb350c
Author: Ming Ma <mingma@twitter.com>
Authored: Tue Jan 17 22:41:31 2017 -0800
Committer: Ming Ma <mingma@twitter.com>
Committed: Tue Jan 17 22:41:31 2017 -0800

----------------------------------------------------------------------
 CHANGES.txt                                     |   1 +
 .../CartesianProductCombination.java            |   3 +
 .../CartesianProductConfig.java                 |  10 ++
 .../CartesianProductEdgeManagerConfig.java      |  12 +-
 ...artesianProductEdgeManagerUnpartitioned.java |  60 +++++---
 .../CartesianProductVertexManager.java          |  26 ++++
 .../CartesianProductVertexManagerConfig.java    |  20 ++-
 ...artesianProductVertexManagerPartitioned.java |   4 +
 .../CartesianProductVertexManagerReal.java      |   3 +-
 ...tesianProductVertexManagerUnpartitioned.java | 117 +++++++++++++--
 .../tez/runtime/library/utils/Grouper.java      |  89 ++++++++++++
 .../main/proto/CartesianProductPayload.proto    |   5 +-
 .../TestCartesianProductCombination.java        |   9 ++
 .../TestCartesianProductConfig.java             |  34 ++++-
 .../TestCartesianProductEdgeManager.java        |  13 +-
 .../TestCartesianProductEdgeManagerConfig.java  |  50 +++++++
 ...tCartesianProductEdgeManagerPartitioned.java |   6 +-
 ...artesianProductEdgeManagerUnpartitioned.java | 125 ++++++++++++----
 ...TestCartesianProductVertexManagerConfig.java |  53 +++++++
 ...artesianProductVertexManagerPartitioned.java |   9 +-
 ...tesianProductVertexManagerUnpartitioned.java | 143 ++++++++++++++++++-
 .../library/cartesianproduct/TestGrouper.java   |  80 +++++++++++
 22 files changed, 788 insertions(+), 84 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/tez/blob/506c9bc3/CHANGES.txt
----------------------------------------------------------------------
diff --git a/CHANGES.txt b/CHANGES.txt
index cfc5214..6538006 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -7,6 +7,7 @@ INCOMPATIBLE CHANGES
 
 ALL CHANGES:
 
+  TEZ-3458. Auto grouping for cartesian product edge(unpartitioned case).
   TEZ-3574. Container reuse won't pickup extra dag level local resource.
   TEZ-3443. Remove a repeated/unused method from MRTask.
   TEZ-3551. FrameworkClient created twice causing minor delay.

http://git-wip-us.apache.org/repos/asf/tez/blob/506c9bc3/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductCombination.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductCombination.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductCombination.java
index a46993d..c6c95f2 100644
--- a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductCombination.java
+++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductCombination.java
@@ -18,6 +18,7 @@
 package org.apache.tez.runtime.library.cartesianproduct;
 
 import com.google.common.base.Preconditions;
+import com.google.common.primitives.Ints;
 
 import java.util.Arrays;
 import java.util.Collections;
@@ -54,6 +55,8 @@ class CartesianProductCombination {
   private final Integer[] factor;
 
   public CartesianProductCombination(int[] numPartitionOrTask) {
+    Preconditions.checkArgument(!Ints.contains(numPartitionOrTask, 0),
+      "CartesianProductCombination doesn't allow zero partition or task");
     this.numPartitionOrTask = Arrays.copyOf(numPartitionOrTask, numPartitionOrTask.length);
     combination = new Integer[numPartitionOrTask.length];
     factor = new Integer[numPartitionOrTask.length];

http://git-wip-us.apache.org/repos/asf/tez/blob/506c9bc3/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductConfig.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductConfig.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductConfig.java
index a7a3940..b57ed84 100644
--- a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductConfig.java
+++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductConfig.java
@@ -212,6 +212,16 @@ public class CartesianProductConfig {
       builder.setMaxFraction(conf.getFloat(
         CartesianProductVertexManager.TEZ_CARTESIAN_PRODUCT_SLOW_START_MAX_FRACTION,
         CartesianProductVertexManager.TEZ_CARTESIAN_PRODUCT_SLOW_START_MAX_FRACTION_DEFAULT));
+      String enableAutoGrouping =
+        conf.get(CartesianProductVertexManager.TEZ_CARTESIAN_PRODUCT_ENABLE_AUTO_GROUPING);
+      if (enableAutoGrouping != null) {
+        builder.setEnableAutoGrouping(Boolean.parseBoolean(enableAutoGrouping));
+      }
+      String desiredBytesPerGroup =
+        conf.get(CartesianProductVertexManager.TEZ_CARTESIAN_PRODUCT_DESIRED_BYTES_PER_GROUP);
+      if (desiredBytesPerGroup != null) {
+        builder.setDesiredBytesPerGroup(Long.parseLong(desiredBytesPerGroup));
+      }
     }
     Preconditions.checkArgument(builder.getMinFraction() <= builder.getMaxFraction(),
       "min fraction(" + builder.getMinFraction() + ") should be less than max fraction(" +

http://git-wip-us.apache.org/repos/asf/tez/blob/506c9bc3/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductEdgeManagerConfig.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductEdgeManagerConfig.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductEdgeManagerConfig.java
index d48a0bb..0347f67 100644
--- a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductEdgeManagerConfig.java
+++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductEdgeManagerConfig.java
@@ -28,18 +28,24 @@ import static org.apache.tez.runtime.library.cartesianproduct.CartesianProductUs
 
 class CartesianProductEdgeManagerConfig extends CartesianProductConfig {
   private final int[] numTasks;
+  private final int[] numGroups;
 
   protected CartesianProductEdgeManagerConfig(boolean isPartitioned, String[] sourceVertices,
-                                            int[] numPartitions, int[] numTasks,
+                                            int[] numPartitions, int[] numTasks, int[] numGroups,
                                             CartesianProductFilterDescriptor filterDescriptor) {
     super(isPartitioned, numPartitions, sourceVertices, filterDescriptor);
     this.numTasks = numTasks;
+    this.numGroups = numGroups;
   }
 
   public int[] getNumTasks() {
     return this.numTasks;
   }
 
+  public int[] getNumGroups() {
+    return this.numGroups;
+  }
+
   public static CartesianProductEdgeManagerConfig fromUserPayload(UserPayload payload)
     throws InvalidProtocolBufferException {
     CartesianProductConfigProto proto =
@@ -58,7 +64,9 @@ class CartesianProductEdgeManagerConfig extends CartesianProductConfig {
     }
     int[] numTasks =
       proto.getNumTasksCount() == 0 ? null : Ints.toArray(proto.getNumTasksList());
+    int[] numGroups =
+      proto.getNumGroupsCount() == 0 ? null : Ints.toArray(proto.getNumGroupsList());
     return new CartesianProductEdgeManagerConfig(isPartitioned, sourceVertices, numPartitions,
-      numTasks, filterDescriptor);
+      numTasks, numGroups, filterDescriptor);
   }
 }
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/tez/blob/506c9bc3/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductEdgeManagerUnpartitioned.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductEdgeManagerUnpartitioned.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductEdgeManagerUnpartitioned.java
index 9e46e95..b9cb155 100644
--- a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductEdgeManagerUnpartitioned.java
+++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductEdgeManagerUnpartitioned.java
@@ -18,17 +18,18 @@
 package org.apache.tez.runtime.library.cartesianproduct;
 
 import org.apache.tez.dag.api.EdgeManagerPluginContext;
-import org.apache.tez.dag.api.EdgeManagerPluginOnDemand;
+import org.apache.tez.dag.api.EdgeManagerPluginOnDemand.CompositeEventRouteMetadata;
+import org.apache.tez.dag.api.EdgeManagerPluginOnDemand.EventRouteMetadata;
+import org.apache.tez.runtime.library.utils.Grouper;
 
 import javax.annotation.Nullable;
-import java.util.Arrays;
 
-import static org.apache.tez.dag.api.EdgeManagerPluginOnDemand.*;
 
 class CartesianProductEdgeManagerUnpartitioned extends CartesianProductEdgeManagerReal {
   private int positionId;
-  private int[] numTasks;
+  private int[] numGroups;
   private int numDestinationConsumerTasks;
+  private Grouper grouper = new Grouper();
 
   public CartesianProductEdgeManagerUnpartitioned(EdgeManagerPluginContext context) {
     super(context);
@@ -36,39 +37,48 @@ class CartesianProductEdgeManagerUnpartitioned extends CartesianProductEdgeManag
 
   public void initialize(CartesianProductEdgeManagerConfig config) {
     positionId = config.getSourceVertices().indexOf(getContext().getSourceVertexName());
-    this.numTasks = config.getNumTasks();
+    this.numGroups = config.getNumGroups();
 
-    if (numTasks != null && numTasks[positionId] != 0) {
+    if (numGroups != null && numGroups[positionId] != 0) {
+      grouper.init(config.getNumTasks()[positionId], numGroups[positionId]);
       numDestinationConsumerTasks = 1;
-      for (int numTask : numTasks) {
-        numDestinationConsumerTasks *= numTask;
+      for (int numGroup : numGroups) {
+        numDestinationConsumerTasks *= numGroup;
       }
-      numDestinationConsumerTasks /= numTasks[positionId];
+      numDestinationConsumerTasks /= numGroups[positionId];
     }
   }
 
   @Override
   public int routeInputErrorEventToSource(int destTaskId, int failedInputId) throws Exception {
-    return
-      CartesianProductCombination.fromTaskId(numTasks, destTaskId).getCombination().get(positionId);
+    return failedInputId + grouper.getFirstTaskInGroup(
+      CartesianProductCombination.fromTaskId(numGroups, destTaskId).getCombination().get(positionId));
   }
 
   @Override
   public EventRouteMetadata routeDataMovementEventToDestination(int srcTaskId, int srcOutputId,
                                                                 int destTaskId) throws Exception {
-    int index = CartesianProductCombination.fromTaskId(numTasks, destTaskId)
-      .getCombination().get(positionId);
-    return index == srcTaskId ? EventRouteMetadata.create(1, new int[]{0}) : null;
+    int groupId =
+      CartesianProductCombination.fromTaskId(numGroups, destTaskId).getCombination().get(positionId);
+    if (grouper.isInGroup(srcTaskId, groupId)) {
+      int idx = srcTaskId - grouper.getFirstTaskInGroup(groupId);
+      return EventRouteMetadata.create(1, new int[] {idx});
+    }
+    return null;
   }
 
   @Nullable
   @Override
   public CompositeEventRouteMetadata routeCompositeDataMovementEventToDestination(int srcTaskId,
-                                                                         int destTaskId)
+                                                                                  int destTaskId)
     throws Exception {
-    int index = CartesianProductCombination.fromTaskId(numTasks, destTaskId)
-        .getCombination().get(positionId);
-    return index == srcTaskId ? CompositeEventRouteMetadata.create(1, 0, 0) : null;
+    int groupId =
+      CartesianProductCombination.fromTaskId(numGroups, destTaskId).getCombination().get(positionId);
+    if (grouper.isInGroup(srcTaskId, groupId)) {
+      int idx = srcTaskId - grouper.getFirstTaskInGroup(groupId);
+      return CompositeEventRouteMetadata.create(1, idx, 0);
+    }
+    return null;
   }
 
   @Nullable
@@ -76,14 +86,20 @@ class CartesianProductEdgeManagerUnpartitioned extends CartesianProductEdgeManag
   public EventRouteMetadata routeInputSourceTaskFailedEventToDestination(int srcTaskId,
                                                                          int destTaskId)
     throws Exception {
-    int index = CartesianProductCombination.fromTaskId(numTasks, destTaskId)
-      .getCombination().get(positionId);
-    return index == srcTaskId ? EventRouteMetadata.create(1, new int[]{0}) : null;
+    int groupId =
+      CartesianProductCombination.fromTaskId(numGroups, destTaskId).getCombination().get(positionId);
+    if (grouper.isInGroup(srcTaskId, groupId)) {
+      int idx = srcTaskId - grouper.getFirstTaskInGroup(groupId);
+      return EventRouteMetadata.create(1, new int[] {idx});
+    }
+    return null;
   }
 
   @Override
   public int getNumDestinationTaskPhysicalInputs(int destTaskId) {
-    return 1;
+    int groupId =
+      CartesianProductCombination.fromTaskId(numGroups, destTaskId).getCombination().get(positionId);
+    return grouper.getNumTasksInGroup(groupId);
   }
 
   @Override

http://git-wip-us.apache.org/repos/asf/tez/blob/506c9bc3/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManager.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManager.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManager.java
index 83caac2..38e2355 100644
--- a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManager.java
+++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManager.java
@@ -47,13 +47,39 @@ import static org.apache.tez.dag.api.EdgeProperty.DataMovementType.CUSTOM;
  * determined by vertex manager.
  */
 public class CartesianProductVertexManager extends VertexManagerPlugin {
+  /**
+   * Begin scheduling tasks when the fraction of finished cartesian product source tasks reaches
+   * this value.
+   */
   public static final String TEZ_CARTESIAN_PRODUCT_SLOW_START_MIN_FRACTION =
     "tez.cartesian-product.min-src-fraction";
   public static final float TEZ_CARTESIAN_PRODUCT_SLOW_START_MIN_FRACTION_DEFAULT = 0.25f;
+
+  /**
+   * Schedule all tasks when the fraction of finished cartesian product source tasks reaches this value.
+   */
   public static final String TEZ_CARTESIAN_PRODUCT_SLOW_START_MAX_FRACTION =
     "tez.cartesian-product.min-src-fraction";
   public static final float TEZ_CARTESIAN_PRODUCT_SLOW_START_MAX_FRACTION_DEFAULT = 0.75f;
 
+  /**
+   * Enables automatic grouping. It groups the source tasks of each cartesian product source vertex
+   * so that every group generates a similar output size. Parallelism can then be reduced, because
+   * destination tasks handle combinations of per-group output instead of per-task output. This is
+   * only available for the unpartitioned case for now, and it's useful for scenarios where many
+   * source tasks generate small outputs.
+   */
+  public static final String TEZ_CARTESIAN_PRODUCT_ENABLE_AUTO_GROUPING =
+    "tez.cartesian-product.enable-auto-grouping";
+  public static final boolean TEZ_CARTESIAN_PRODUCT_ENABLE_AUTO_GROUPING_DEFAULT = true;
+
+  /**
+   * The number of output bytes we want from each group.
+   */
+  public static final String TEZ_CARTESIAN_PRODUCT_DESIRED_BYTES_PER_GROUP =
+    "tez.cartesian-product.desired-input-per-src";
+  public static final long TEZ_CARTESIAN_PRODUCT_DESIRED_BYTES_PER_GROUP_DEFAULT = 32 * 1024 * 1024;
+
   private CartesianProductVertexManagerReal vertexManagerReal = null;
 
   public CartesianProductVertexManager(VertexManagerPluginContext context) {

http://git-wip-us.apache.org/repos/asf/tez/blob/506c9bc3/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManagerConfig.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManagerConfig.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManagerConfig.java
index b324524..f43f494 100644
--- a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManagerConfig.java
+++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManagerConfig.java
@@ -30,10 +30,13 @@ import static org.apache.tez.runtime.library.cartesianproduct.CartesianProductUs
 class CartesianProductVertexManagerConfig extends CartesianProductConfig {
   private final float minFraction;
   private final float maxFraction;
+  private final boolean enableAutoGrouping;
+  private final long desiredBytesPerGroup;
 
   public CartesianProductVertexManagerConfig(boolean isPartitioned, String[] sourceVertices,
                                              int[] numPartitions,
                                              float minFraction, float maxFraction,
+                                             boolean enableAutoGrouping, long desiredBytesPerGroup,
                                              CartesianProductFilterDescriptor filterDescriptor) {
     super(isPartitioned, numPartitions, sourceVertices, filterDescriptor);
     Preconditions.checkArgument(minFraction <= maxFraction,
@@ -41,6 +44,8 @@ class CartesianProductVertexManagerConfig extends CartesianProductConfig {
         maxFraction  + ") in cartesian product slow start");
     this.minFraction = minFraction;
     this.maxFraction = maxFraction;
+    this.enableAutoGrouping = enableAutoGrouping;
+    this.desiredBytesPerGroup = desiredBytesPerGroup;
   }
 
   public float getMinFraction() {
@@ -69,7 +74,20 @@ class CartesianProductVertexManagerConfig extends CartesianProductConfig {
     }
     float minFraction = proto.getMinFraction();
     float maxFraction = proto.getMaxFraction();
+
+    boolean enableAutoGrouping = proto.hasEnableAutoGrouping() ? proto.getEnableAutoGrouping()
+      : CartesianProductVertexManager.TEZ_CARTESIAN_PRODUCT_ENABLE_AUTO_GROUPING_DEFAULT;
+    long desiredBytesPerGroup = proto.hasDesiredBytesPerGroup() ? proto.getDesiredBytesPerGroup()
+      : CartesianProductVertexManager.TEZ_CARTESIAN_PRODUCT_DESIRED_BYTES_PER_GROUP_DEFAULT;
     return new CartesianProductVertexManagerConfig(isPartitioned, sourceVertices, numPartitions,
-      minFraction, maxFraction, filterDescriptor);
+      minFraction, maxFraction, enableAutoGrouping, desiredBytesPerGroup, filterDescriptor);
+  }
+
+  public boolean isEnableAutoGrouping() {
+    return enableAutoGrouping;
+  }
+
+  public long getDesiredBytesPerGroup() {
+    return desiredBytesPerGroup;
   }
 }

http://git-wip-us.apache.org/repos/asf/tez/blob/506c9bc3/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManagerPartitioned.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManagerPartitioned.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManagerPartitioned.java
index 38ec1b1..85c04d2 100644
--- a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManagerPartitioned.java
+++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManagerPartitioned.java
@@ -26,6 +26,7 @@ import org.apache.tez.dag.api.VertexManagerPluginContext.ScheduleTaskRequest;
 import org.apache.tez.dag.api.event.VertexState;
 import org.apache.tez.dag.api.event.VertexStateUpdate;
 import org.apache.tez.runtime.api.TaskAttemptIdentifier;
+import org.apache.tez.runtime.api.events.VertexManagerEvent;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -92,6 +93,9 @@ class CartesianProductVertexManagerPartitioned extends CartesianProductVertexMan
   }
 
   @Override
+  public void onVertexManagerEventReceived(VertexManagerEvent vmEvent) throws IOException {}
+
+  @Override
   public synchronized void onVertexStarted(List<TaskAttemptIdentifier> completions)
     throws Exception {
     vertexStarted = true;

http://git-wip-us.apache.org/repos/asf/tez/blob/506c9bc3/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManagerReal.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManagerReal.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManagerReal.java
index 84e65ac..1a397fd 100644
--- a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManagerReal.java
+++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManagerReal.java
@@ -22,6 +22,7 @@ import org.apache.tez.dag.api.event.VertexStateUpdate;
 import org.apache.tez.runtime.api.TaskAttemptIdentifier;
 import org.apache.tez.runtime.api.events.VertexManagerEvent;
 
+import java.io.IOException;
 import java.util.List;
 
 /**
@@ -40,7 +41,7 @@ abstract class CartesianProductVertexManagerReal {
 
   public abstract void initialize(CartesianProductVertexManagerConfig config) throws Exception;
 
-  public void onVertexManagerEventReceived(VertexManagerEvent vmEvent) {}
+  public abstract void onVertexManagerEventReceived(VertexManagerEvent vmEvent) throws IOException;
 
   public abstract void onVertexStarted(List<TaskAttemptIdentifier> completions) throws Exception;
 

http://git-wip-us.apache.org/repos/asf/tez/blob/506c9bc3/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManagerUnpartitioned.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManagerUnpartitioned.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManagerUnpartitioned.java
index 5114293..993cb40 100644
--- a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManagerUnpartitioned.java
+++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManagerUnpartitioned.java
@@ -18,6 +18,7 @@
 package org.apache.tez.runtime.library.cartesianproduct;
 
 import com.google.common.primitives.Ints;
+import com.google.protobuf.ByteString;
 import org.apache.tez.dag.api.EdgeManagerPluginDescriptor;
 import org.apache.tez.dag.api.EdgeProperty;
 import org.apache.tez.dag.api.UserPayload;
@@ -26,23 +27,33 @@ import org.apache.tez.dag.api.VertexManagerPluginContext.ScheduleTaskRequest;
 import org.apache.tez.dag.api.event.VertexState;
 import org.apache.tez.dag.api.event.VertexStateUpdate;
 import org.apache.tez.runtime.api.TaskAttemptIdentifier;
+import org.apache.tez.runtime.api.events.VertexManagerEvent;
+import org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads.VertexManagerEventPayloadProto;
+import org.apache.tez.runtime.library.utils.Grouper;
 import org.roaringbitmap.RoaringBitmap;
+import org.slf4j.Logger;
 
 import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.EnumSet;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.Iterator;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
 import java.util.Queue;
+import java.util.Set;
 
 import static org.apache.tez.dag.api.EdgeProperty.DataMovementType.CUSTOM;
 import static org.apache.tez.runtime.library.cartesianproduct.CartesianProductUserPayload.CartesianProductConfigProto;
 
 class CartesianProductVertexManagerUnpartitioned extends CartesianProductVertexManagerReal {
+  private static final Logger LOG =
+    org.slf4j.LoggerFactory.getLogger(CartesianProductVertexManagerUnpartitioned.class);
+
   List<String> sourceVertices;
   private int parallelism = 1;
   private boolean vertexReconfigured = false;
@@ -57,6 +68,15 @@ class CartesianProductVertexManagerUnpartitioned extends CartesianProductVertexM
   private RoaringBitmap scheduledTasks = new RoaringBitmap();
   private CartesianProductConfig config;
 
+  /* auto grouping related */
+  private int[] numGroups;
+  private Set<String> vertexSentVME = new HashSet<>();
+  private long[] vertexOutputBytes;
+  private int[] numVertexManagerEventsReceived;
+  private long desiredBytesPerGroup;
+  private boolean enableGrouping;
+  private Grouper grouper = new Grouper();
+
   public CartesianProductVertexManagerUnpartitioned(VertexManagerPluginContext context) {
     super(context);
   }
@@ -65,6 +85,16 @@ class CartesianProductVertexManagerUnpartitioned extends CartesianProductVertexM
   public void initialize(CartesianProductVertexManagerConfig config) throws Exception {
     sourceVertices = config.getSourceVertices();
     numTasks = new int[sourceVertices.size()];
+    numGroups = new int[sourceVertices.size()];
+    vertexOutputBytes = new long[sourceVertices.size()];
+    numVertexManagerEventsReceived = new int[sourceVertices.size()];
+
+    enableGrouping = config.isEnableAutoGrouping();
+    desiredBytesPerGroup = config.getDesiredBytesPerGroup();
+
+    for (String vertex : sourceVertices) {
+      sourceTaskCompleted.put(vertex, new RoaringBitmap());
+    }
 
     for (String vertex : getContext().getInputVertexEdgeProperties().keySet()) {
       if (sourceVertices.indexOf(vertex) != -1) {
@@ -138,15 +168,72 @@ class CartesianProductVertexManagerUnpartitioned extends CartesianProductVertexM
     return true;
   }
 
+  public synchronized void onVertexManagerEventReceived(VertexManagerEvent vmEvent)
+    throws IOException {
+    /* vmEvent after reconfigure doesn't matter */
+    if (vertexReconfigured) {
+      return;
+    }
+
+    if (vmEvent.getUserPayload() != null) {
+      String srcVertex =
+        vmEvent.getProducerAttemptIdentifier().getTaskIdentifier().getVertexIdentifier().getName();
+      int position = sourceVertices.indexOf(srcVertex);
+      // vmEvent from non-cp vertex doesn't matter
+      if (position == -1) {
+        return;
+      }
+      VertexManagerEventPayloadProto proto =
+        VertexManagerEventPayloadProto.parseFrom(ByteString.copyFrom(vmEvent.getUserPayload()));
+      vertexOutputBytes[position] += proto.getOutputSize();
+      numVertexManagerEventsReceived[position]++;
+      vertexSentVME.add(srcVertex);
+    }
+
+    tryScheduleTasks();
+  }
+
   private boolean tryReconfigure() throws IOException {
     if (numCPSrcNotInConfigureState > 0) {
       return false;
     }
+    if (enableGrouping) {
+      if (vertexSentVME.size() != sourceVertices.size()) {
+        return false;
+      }
+      for (int i = 0; i < vertexOutputBytes.length; i++) {
+        if (vertexOutputBytes[i] < desiredBytesPerGroup
+          && numVertexManagerEventsReceived[i] < numTasks[i]) {
+          return false;
+        }
+      }
+    }
+
+    LOG.info("Start reconfigure, grouping: " + enableGrouping
+      + ", group size: " + desiredBytesPerGroup);
+    LOG.info("src vertices: " + sourceVertices);
+    LOG.info("number of source tasks in each src: " + Arrays.toString(numTasks));
+    LOG.info("number of vmEvent from each src: "
+      + Arrays.toString(numVertexManagerEventsReceived));
+    LOG.info("output stats of each src: " + Arrays.toString(vertexOutputBytes));
 
-    for (int numTask : numTasks) {
-      parallelism *= numTask;
+    for (int i = 0; i < numTasks.length; i++) {
+      if (enableGrouping) {
+        vertexOutputBytes[i] =
+          vertexOutputBytes[i] * numTasks[i] / numVertexManagerEventsReceived[i];
+        int desiredNumGroup =
+          (int) ((vertexOutputBytes[i] + desiredBytesPerGroup - 1) / desiredBytesPerGroup);
+        numGroups[i] = Math.min(numTasks[i], desiredNumGroup);
+      } else {
+        numGroups[i] = numTasks[i];
+      }
+      parallelism *= numGroups[i];
     }
 
+    LOG.info("estimated output size of each src: " + Arrays.toString(vertexOutputBytes));
+    LOG.info("number of groups for each src: " + Arrays.toString(numGroups));
+    LOG.info("Final parallelism: " + parallelism);
+
     UserPayload payload = null;
     Map<String, EdgeProperty> edgeProperties = getContext().getInputVertexEdgeProperties();
     Iterator<Map.Entry<String,EdgeProperty>> iter = edgeProperties.entrySet().iterator();
@@ -160,7 +247,7 @@ class CartesianProductVertexManagerUnpartitioned extends CartesianProductVertexM
       if (payload == null) {
         CartesianProductConfigProto.Builder builder = CartesianProductConfigProto.newBuilder();
         builder.setIsPartitioned(false).addAllNumTasks(Ints.asList(numTasks))
-          .addAllSourceVertices(config.getSourceVertices());
+          .addAllNumGroups(Ints.asList(numGroups)).addAllSourceVertices(config.getSourceVertices());
         payload = UserPayload.create(ByteBuffer.wrap(builder.build().toByteArray()));
       }
       descriptor.setUserPayload(payload);
@@ -187,21 +274,35 @@ class CartesianProductVertexManagerUnpartitioned extends CartesianProductVertexM
   private void scheduledTasksDependOnCompletion(TaskAttemptIdentifier attempt) {
     int taskId = attempt.getTaskIdentifier().getIdentifier();
     String vertex = attempt.getTaskIdentifier().getVertexIdentifier().getName();
+    int position = sourceVertices.indexOf(vertex);
 
     List<ScheduleTaskRequest> requests = new ArrayList<>();
     CartesianProductCombination combination =
-      new CartesianProductCombination(numTasks, sourceVertices.indexOf(vertex));
-    combination.firstTaskWithFixedPartition(taskId);
+      new CartesianProductCombination(numGroups, position);
+    grouper.init(numTasks[position], numGroups[position]);
+    combination.firstTaskWithFixedPartition(grouper.getGroupId(taskId));
     do {
       List<Integer> list = combination.getCombination();
+
+      if (scheduledTasks.contains(combination.getTaskId())) {
+        continue;
+      }
       boolean readyToSchedule = true;
       for (int i = 0; i < list.size(); i++) {
-        if (!sourceTaskCompleted.get(sourceVertices.get(i)).contains(list.get(i))) {
-          readyToSchedule = false;
+        int group = list.get(i);
+        grouper.init(numTasks[i], numGroups[i]);
+        for (int j = grouper.getFirstTaskInGroup(group); j <= grouper.getLastTaskInGroup(group); j++) {
+          if (!sourceTaskCompleted.get(sourceVertices.get(i)).contains(j)) {
+            readyToSchedule = false;
+            break;
+          }
+        }
+        if (!readyToSchedule) {
           break;
         }
       }
-      if (readyToSchedule && !scheduledTasks.contains(combination.getTaskId())) {
+
+      if (readyToSchedule) {
         requests.add(ScheduleTaskRequest.create(combination.getTaskId(), null));
         scheduledTasks.add(combination.getTaskId());
       }

http://git-wip-us.apache.org/repos/asf/tez/blob/506c9bc3/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/utils/Grouper.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/utils/Grouper.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/utils/Grouper.java
new file mode 100644
index 0000000..73a8c87
--- /dev/null
+++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/utils/Grouper.java
@@ -0,0 +1,89 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.tez.runtime.library.utils;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * This grouper groups a specified number of tasks into a specified number of groups.
+ *
+ * If numTask%numGroup is zero, every group has numTask/numGroup tasks.
+ * Otherwise, every group first gets numTask/numGroup tasks, and the remaining tasks are
+ * distributed over the last numTask%numGroup groups (one extra task for each such group).
+ * For example, if we group 8 tasks into 3 groups, the groups get {2, 3, 3} tasks.
+ */
+public class Grouper {
+  private int numGroup;
+  private int numTask;
+  private int numGroup1;     // leading groups holding taskPerGroup1 tasks
+  private int taskPerGroup1; // numTask / numGroup
+  private int numGroup2;     // trailing groups holding one extra task (numTask % numGroup)
+  private int taskPerGroup2; // taskPerGroup1 + 1
+
+  public Grouper init(int numTask, int numGroup) { // must be called before any other method
+    Preconditions.checkArgument(numGroup > 0,
+      "Number of groups is " + numGroup + ". Should be positive");
+    Preconditions.checkArgument(numTask > 0,
+      "Number of tasks is " + numTask + ". Should be positive");
+    Preconditions.checkArgument(numTask >= numGroup,
+      "Num of groups " + numGroup + " shouldn't be more than number of tasks " + numTask);
+    this.numTask = numTask;
+    this.numGroup = numGroup;
+    this.taskPerGroup1 = numTask / numGroup;
+    this.taskPerGroup2 = taskPerGroup1 + 1;
+    this.numGroup2 = numTask % numGroup;
+    this.numGroup1 = numGroup - numGroup2;
+
+    return this; // returns itself so calls can be chained
+  }
+
+  public int getFirstTaskInGroup(int groupId) { // id of the group's first task
+    Preconditions.checkArgument(0 <= groupId && groupId < numGroup, "Invalid groupId " + groupId);
+    if (groupId < numGroup1) {
+      return groupId * taskPerGroup1;
+    } else {
+      return groupId * taskPerGroup1 + (groupId - numGroup1); // +1 per earlier larger group
+    }
+  }
+
+  public int getNumTasksInGroup(int groupId) {
+    Preconditions.checkArgument(0 <= groupId && groupId < numGroup, "Invalid groupId " + groupId);
+    return groupId < numGroup1 ? taskPerGroup1 : taskPerGroup2;
+  }
+
+  public int getLastTaskInGroup(int groupId) { // id of the group's last task (inclusive)
+    Preconditions.checkArgument(0 <= groupId && groupId < numGroup, "Invalid groupId " + groupId);
+    return getFirstTaskInGroup(groupId) + getNumTasksInGroup(groupId) - 1;
+  }
+
+  public int getGroupId(int taskId) { // id of the group that contains taskId
+    Preconditions.checkArgument(0 <= taskId && taskId < numTask, "Invalid taskId " + taskId);
+    if (taskId < taskPerGroup1 * numGroup1) { // task lies in one of the smaller leading groups
+      return taskId/taskPerGroup1;
+    } else {
+      return numGroup1 + (taskId - taskPerGroup1 * numGroup1) / taskPerGroup2;
+    }
+  }
+
+  public boolean isInGroup(int taskId, int groupId) {
+    Preconditions.checkArgument(0 <= groupId && groupId < numGroup, "Invalid groupId " + groupId);
+    Preconditions.checkArgument(0 <= taskId && taskId < numTask, "Invalid taskId " + taskId);
+    return getFirstTaskInGroup(groupId) <= taskId
+      && taskId < getFirstTaskInGroup(groupId) + getNumTasksInGroup(groupId);
+  }
+}

http://git-wip-us.apache.org/repos/asf/tez/blob/506c9bc3/tez-runtime-library/src/main/proto/CartesianProductPayload.proto
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/proto/CartesianProductPayload.proto b/tez-runtime-library/src/main/proto/CartesianProductPayload.proto
index 39ba82c..dd7d06f 100644
--- a/tez-runtime-library/src/main/proto/CartesianProductPayload.proto
+++ b/tez-runtime-library/src/main/proto/CartesianProductPayload.proto
@@ -27,5 +27,8 @@ message CartesianProductConfigProto {
     optional bytes filterUserPayload = 5;
     optional float minFraction = 6;
     optional float maxFraction = 7;
-    repeated int32 numTasks = 8;
+    optional bool enableAutoGrouping = 8;
+    optional int64 desiredBytesPerGroup = 9;
+    repeated int32 numTasks = 10;
+    repeated int32 numGroups = 11;
 }
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/tez/blob/506c9bc3/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductCombination.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductCombination.java b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductCombination.java
index 4a2827a..06d3e90 100644
--- a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductCombination.java
+++ b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductCombination.java
@@ -107,4 +107,13 @@ public class TestCartesianProductCombination {
       assertTrue(list.get(1) == i%3);
     }
   }
+
+  @Test(timeout = 5000)
+  public void testRejectZero() { // a source with 0 tasks is expected to be rejected
+    int[] numTasks = new int[] {0 ,1};
+    try {
+      new CartesianProductCombination(numTasks);
+      assertTrue(false); // AssertionError is not an Exception, so the catch below won't hide it
+    } catch (Exception ignored) {} // the exception is the expected outcome
+  }
 }

http://git-wip-us.apache.org/repos/asf/tez/blob/506c9bc3/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductConfig.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductConfig.java b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductConfig.java
index 2de750f..c9e49a3 100644
--- a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductConfig.java
+++ b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductConfig.java
@@ -20,6 +20,8 @@ package org.apache.tez.runtime.library.cartesianproduct;
 import com.google.common.primitives.Ints;
 import org.apache.tez.dag.api.TezConfiguration;
 import org.apache.tez.dag.api.UserPayload;
+import org.apache.tez.runtime.library.cartesianproduct.CartesianProductUserPayload.CartesianProductConfigProto;
+import org.junit.Before;
 import org.junit.Test;
 
 import java.io.IOException;
@@ -32,11 +34,17 @@ import java.util.Random;
 
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
 
 public class TestCartesianProductConfig {
-  private TezConfiguration conf = new TezConfiguration();
+  private TezConfiguration conf;
+
+  @Before
+  public void setup() {
+    conf = new TezConfiguration(); // rebuilt before each test so one test's settings don't leak
+  }
 
   @Test(timeout = 5000)
   public void testSerializationPartitioned() throws IOException {
@@ -103,4 +111,26 @@ public class TestCartesianProductConfig {
       assertNull(descriptor2);
     }
   }
+
+  @Test(timeout = 5000)
+  public void testAutoGroupingConfig() {
+    List<String> sourceVertices = new ArrayList<>();
+    sourceVertices.add("v0");
+    sourceVertices.add("v1");
+    CartesianProductConfig config = new CartesianProductConfig(sourceVertices);
+
+    // auto grouping conf not set: the proto leaves both grouping fields unset
+    CartesianProductConfigProto proto = config.toProto(conf);
+    assertFalse(proto.hasEnableAutoGrouping());
+    assertFalse(proto.hasDesiredBytesPerGroup());
+
+    // auto grouping conf set: both fields are written with the configured values
+    conf.setBoolean(CartesianProductVertexManager.TEZ_CARTESIAN_PRODUCT_ENABLE_AUTO_GROUPING, true);
+    conf.setLong(CartesianProductVertexManager.TEZ_CARTESIAN_PRODUCT_DESIRED_BYTES_PER_GROUP, 1000);
+    proto = config.toProto(conf);
+    assertTrue(proto.hasEnableAutoGrouping());
+    assertTrue(proto.hasDesiredBytesPerGroup());
+    assertEquals(true, proto.getEnableAutoGrouping());
+    assertEquals(1000, proto.getDesiredBytesPerGroup());
+  }
 }
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/tez/blob/506c9bc3/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductEdgeManager.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductEdgeManager.java b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductEdgeManager.java
index 9581a6e..12aee3b 100644
--- a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductEdgeManager.java
+++ b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductEdgeManager.java
@@ -23,9 +23,7 @@ import org.apache.tez.dag.api.UserPayload;
 import org.junit.Test;
 
 import java.nio.ByteBuffer;
-import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.List;
 
 import static org.apache.tez.runtime.library.cartesianproduct.CartesianProductUserPayload.*;
 import static org.junit.Assert.assertTrue;
@@ -42,8 +40,8 @@ public class TestCartesianProductEdgeManager {
     // partitioned case
     CartesianProductConfigProto.Builder builder = CartesianProductConfigProto.newBuilder();
     builder.setIsPartitioned(true)
-      .addAllSourceVertices(Arrays.asList(new String[]{"v0", "v1"}))
-      .addAllNumPartitions(Ints.asList(new int[]{2,3}));
+      .addAllSourceVertices(Arrays.asList("v0", "v1"))
+      .addAllNumPartitions(Ints.asList(2,3));
     UserPayload payload = UserPayload.create(ByteBuffer.wrap(builder.build().toByteArray()));
     when(context.getUserPayload()).thenReturn(payload);
     edgeManager.initialize();
@@ -51,13 +49,10 @@ public class TestCartesianProductEdgeManager {
       instanceof CartesianProductEdgeManagerPartitioned);
 
     // unpartitioned case
-    List<String> sourceVertices = new ArrayList<>();
-    sourceVertices.add("v0");
-    sourceVertices.add("v1");
     builder.clear();
     builder.setIsPartitioned(false)
-      .addAllSourceVertices(Arrays.asList(new String[]{"v0", "v1"}))
-      .addAllNumTasks(Ints.asList(new int[]{2,3}));
+      .addAllSourceVertices(Arrays.asList("v0", "v1"))
+      .addAllNumTasks(Ints.asList(2,3));
     payload = UserPayload.create(ByteBuffer.wrap(builder.build().toByteArray()));
     when(context.getUserPayload()).thenReturn(payload);
     when(context.getSourceVertexNumTasks()).thenReturn(2);

http://git-wip-us.apache.org/repos/asf/tez/blob/506c9bc3/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductEdgeManagerConfig.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductEdgeManagerConfig.java b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductEdgeManagerConfig.java
new file mode 100644
index 0000000..9f6fa09
--- /dev/null
+++ b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductEdgeManagerConfig.java
@@ -0,0 +1,50 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.tez.runtime.library.cartesianproduct;
+
+import com.google.common.primitives.Ints;
+import org.apache.tez.dag.api.UserPayload;
+import org.apache.tez.runtime.library.cartesianproduct.CartesianProductUserPayload.CartesianProductConfigProto;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+
+public class TestCartesianProductEdgeManagerConfig {
+  @Test(timeout = 5000)
+  public void testAutoGroupingConfig() throws IOException { // numGroups survives payload parsing
+    List<String> sourceVertices = new ArrayList<>();
+    sourceVertices.add("v0");
+    sourceVertices.add("v1");
+    int[] numTasks = new int[] {4, 5};
+    int[] numGroups = new int[] {2, 3}; // per-source group counts, each <= numTasks
+
+    CartesianProductConfigProto.Builder builder = CartesianProductConfigProto.newBuilder();
+    builder.setIsPartitioned(false).addAllNumTasks(Ints.asList(numTasks))
+      .addAllSourceVertices(sourceVertices).addAllNumGroups(Ints.asList(numGroups));
+    UserPayload payload = UserPayload.create(ByteBuffer.wrap(builder.build().toByteArray()));
+
+    CartesianProductEdgeManagerConfig config =
+      CartesianProductEdgeManagerConfig.fromUserPayload(payload);
+    assertArrayEquals(numGroups, config.getNumGroups()); // read back unchanged
+  }
+}

http://git-wip-us.apache.org/repos/asf/tez/blob/506c9bc3/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductEdgeManagerPartitioned.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductEdgeManagerPartitioned.java b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductEdgeManagerPartitioned.java
index 09f3b52..1afedb9 100644
--- a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductEdgeManagerPartitioned.java
+++ b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductEdgeManagerPartitioned.java
@@ -51,7 +51,7 @@ public class TestCartesianProductEdgeManagerPartitioned {
   @Test(timeout = 5000)
   public void testTwoWay() throws Exception {
     CartesianProductEdgeManagerConfig emConfig =
-      new CartesianProductEdgeManagerConfig(true, new String[]{"v0","v1"}, new int[]{3,4}, null, null);
+      new CartesianProductEdgeManagerConfig(true, new String[]{"v0","v1"}, new int[]{3,4}, null, null, null);
     when(mockContext.getDestinationVertexNumTasks()).thenReturn(12);
     testTwoWayV0(emConfig);
     testTwoWayV1(emConfig);
@@ -145,7 +145,7 @@ public class TestCartesianProductEdgeManagerPartitioned {
       new CartesianProductFilterDescriptor(TestFilter.class.getName())
         .setUserPayload(UserPayload.create(buffer));
     CartesianProductEdgeManagerConfig emConfig =
-      new CartesianProductEdgeManagerConfig(true, new String[]{"v0","v1"}, new int[]{3,4}, null,
+      new CartesianProductEdgeManagerConfig(true, new String[]{"v0","v1"}, new int[]{3,4}, null, null,
         filterDescriptor);
     when(mockContext.getDestinationVertexNumTasks()).thenReturn(3);
     testTwoWayV0WithFilter(emConfig);
@@ -206,7 +206,7 @@ public class TestCartesianProductEdgeManagerPartitioned {
   @Test(timeout = 5000)
   public void testThreeWay() throws Exception {
     CartesianProductEdgeManagerConfig emConfig =
-      new CartesianProductEdgeManagerConfig(true, new String[]{"v0","v1","v2"}, new int[]{4,3,2}, null, null);
+      new CartesianProductEdgeManagerConfig(true, new String[]{"v0","v1","v2"}, new int[]{4,3,2}, null, null, null);
     when(mockContext.getDestinationVertexNumTasks()).thenReturn(24);
     testThreeWayV0(emConfig);
     testThreeWayV1(emConfig);

http://git-wip-us.apache.org/repos/asf/tez/blob/506c9bc3/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductEdgeManagerUnpartitioned.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductEdgeManagerUnpartitioned.java b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductEdgeManagerUnpartitioned.java
index ec97335..db781f3 100644
--- a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductEdgeManagerUnpartitioned.java
+++ b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductEdgeManagerUnpartitioned.java
@@ -47,7 +47,8 @@ public class TestCartesianProductEdgeManagerUnpartitioned {
   @Test(timeout = 5000)
   public void testTwoWay() throws Exception {
     CartesianProductEdgeManagerConfig emConfig =
-      new CartesianProductEdgeManagerConfig(false, new String[]{"v0","v1"}, null, new int[]{2,3}, null);
+      new CartesianProductEdgeManagerConfig(false, new String[]{"v0","v1"}, null,
+        new int[]{2,3}, new int[]{2,3}, null);
     testTwoWayV0(emConfig);
     testTwoWayV1(emConfig);
   }
@@ -57,7 +58,8 @@ public class TestCartesianProductEdgeManagerUnpartitioned {
     when(mockContext.getSourceVertexNumTasks()).thenReturn(2);
     edgeManager.initialize(config);
 
-    CompositeEventRouteMetadata compositeRoutingData = edgeManager.routeCompositeDataMovementEventToDestination(1, 1);
+    CompositeEventRouteMetadata compositeRoutingData =
+      edgeManager.routeCompositeDataMovementEventToDestination(1, 1);
     assertNull(compositeRoutingData);
 
     compositeRoutingData = edgeManager.routeCompositeDataMovementEventToDestination(1, 3);
@@ -69,11 +71,10 @@ public class TestCartesianProductEdgeManagerUnpartitioned {
     EventRouteMetadata routingData = edgeManager.routeInputSourceTaskFailedEventToDestination(1, 1);
     assertNull(routingData);
 
-    compositeRoutingData = edgeManager.routeCompositeDataMovementEventToDestination(1, 3);
-    assertNotNull(compositeRoutingData);
-    assertEquals(1, compositeRoutingData.getCount());
-    assertEquals(0, compositeRoutingData.getTarget());
-    assertEquals(0, compositeRoutingData.getSource());
+    routingData = edgeManager.routeInputSourceTaskFailedEventToDestination(1, 3);
+    assertNotNull(routingData);
+    assertEquals(1, routingData.getNumEvents());
+    assertArrayEquals(new int[]{0}, routingData.getTargetIndices());
 
     assertEquals(0, edgeManager.routeInputErrorEventToSource(1, 0));
 
@@ -87,7 +88,8 @@ public class TestCartesianProductEdgeManagerUnpartitioned {
     when(mockContext.getSourceVertexNumTasks()).thenReturn(3);
     edgeManager.initialize(config);
 
-    CompositeEventRouteMetadata compositeRoutingData = edgeManager.routeCompositeDataMovementEventToDestination(1, 2);
+    CompositeEventRouteMetadata compositeRoutingData =
+      edgeManager.routeCompositeDataMovementEventToDestination(1, 2);
     assertNull(compositeRoutingData);
 
     compositeRoutingData = edgeManager.routeCompositeDataMovementEventToDestination(1, 1);
@@ -99,11 +101,10 @@ public class TestCartesianProductEdgeManagerUnpartitioned {
     EventRouteMetadata routingData = edgeManager.routeInputSourceTaskFailedEventToDestination(1, 2);
     assertNull(routingData);
 
-    compositeRoutingData = edgeManager.routeCompositeDataMovementEventToDestination(1, 1);
-    assertNotNull(compositeRoutingData);
-    assertEquals(1, compositeRoutingData.getCount());
-    assertEquals(0, compositeRoutingData.getTarget());
-    assertEquals(0, compositeRoutingData.getSource());
+    routingData = edgeManager.routeInputSourceTaskFailedEventToDestination(1, 1);
+    assertNotNull(routingData);
+    assertEquals(1, routingData.getNumEvents());
+    assertArrayEquals(new int[]{0}, routingData.getTargetIndices());
 
     assertEquals(1, edgeManager.routeInputErrorEventToSource(1, 0));
 
@@ -120,7 +121,8 @@ public class TestCartesianProductEdgeManagerUnpartitioned {
   @Test(timeout = 5000)
   public void testThreeWay() throws Exception {
     CartesianProductEdgeManagerConfig emConfig =
-      new CartesianProductEdgeManagerConfig(false, new String[]{"v0","v1","v2"}, null, new int[]{2,3,4}, null);
+      new CartesianProductEdgeManagerConfig(false, new String[]{"v0","v1","v2"}, null,
+        new int[]{2,3,4}, new int[]{2,3,4}, null);
     testThreeWayV0(emConfig);
     testThreeWayV1(emConfig);
     testThreeWayV2(emConfig);
@@ -143,11 +145,10 @@ public class TestCartesianProductEdgeManagerUnpartitioned {
     EventRouteMetadata routingData = edgeManager.routeInputSourceTaskFailedEventToDestination(1, 1);
     assertNull(routingData);
 
-    compositeRoutingData = edgeManager.routeCompositeDataMovementEventToDestination(1, 12);
-    assertNotNull(compositeRoutingData);
-    assertEquals(1, compositeRoutingData.getCount());
-    assertEquals(0, compositeRoutingData.getTarget());
-    assertEquals(0, compositeRoutingData.getSource());
+    routingData = edgeManager.routeInputSourceTaskFailedEventToDestination(1, 12);
+    assertNotNull(routingData);
+    assertEquals(1, routingData.getNumEvents());
+    assertArrayEquals(new int[]{0}, routingData.getTargetIndices());
 
     assertEquals(0, edgeManager.routeInputErrorEventToSource(1, 0));
 
@@ -173,11 +174,10 @@ public class TestCartesianProductEdgeManagerUnpartitioned {
     EventRouteMetadata routingData = edgeManager.routeInputSourceTaskFailedEventToDestination(1, 1);
     assertNull(routingData);
 
-    compositeRoutingData = edgeManager.routeCompositeDataMovementEventToDestination(1, 16);
-    assertNotNull(compositeRoutingData);
-    assertEquals(1, compositeRoutingData.getCount());
-    assertEquals(0, compositeRoutingData.getTarget());
-    assertEquals(0, compositeRoutingData.getSource());
+    routingData = edgeManager.routeInputSourceTaskFailedEventToDestination(1, 16);
+    assertNotNull(routingData);
+    assertEquals(1, routingData.getNumEvents());
+    assertArrayEquals(new int[]{0}, routingData.getTargetIndices());
 
     assertEquals(0, edgeManager.routeInputErrorEventToSource(1, 0));
 
@@ -203,11 +203,10 @@ public class TestCartesianProductEdgeManagerUnpartitioned {
     EventRouteMetadata routingData = edgeManager.routeInputSourceTaskFailedEventToDestination(1, 0);
     assertNull(routingData);
 
-    compositeRoutingData = edgeManager.routeCompositeDataMovementEventToDestination(1, 13);
-    assertNotNull(compositeRoutingData);
-    assertEquals(1, compositeRoutingData.getCount());
-    assertEquals(0, compositeRoutingData.getTarget());
-    assertEquals(0, compositeRoutingData.getSource());
+    routingData = edgeManager.routeInputSourceTaskFailedEventToDestination(1, 13);
+    assertNotNull(routingData);
+    assertEquals(1, routingData.getNumEvents());
+    assertArrayEquals(new int[]{0}, routingData.getTargetIndices());
 
     assertEquals(1, edgeManager.routeInputErrorEventToSource(1, 0));
 
@@ -219,7 +218,8 @@ public class TestCartesianProductEdgeManagerUnpartitioned {
   @Test(timeout = 5000)
   public void testZeroSrcTask() {
     CartesianProductEdgeManagerConfig emConfig =
-      new CartesianProductEdgeManagerConfig(false, new String[]{"v0", "v1"}, null, new int[]{2, 0}, null);
+      new CartesianProductEdgeManagerConfig(false, new String[]{"v0", "v1"}, null,
+        new int[]{2,0}, new int[]{2,0}, null);
     testZeroSrcTaskV0(emConfig);
     testZeroSrcTaskV1(emConfig);
   }
@@ -238,4 +238,69 @@ public class TestCartesianProductEdgeManagerUnpartitioned {
     when(mockContext.getSourceVertexNumTasks()).thenReturn(0);
     edgeManager.initialize(config);
   }
+
+  /**
+   * Two-way unpartitioned edge with auto grouping: v0 has 20 tasks in 10 groups,
+   * v1 has 10 tasks in 1 group, giving 10 * 1 = 10 destination tasks.
+   */
+  @Test(timeout = 5000)
+  public void testTwoWayAutoGrouping() throws Exception {
+    CartesianProductEdgeManagerConfig emConfig =
+      new CartesianProductEdgeManagerConfig(false, new String[]{"v0","v1"}, null,
+        new int[]{20, 10}, new int[]{10,1}, null); // numTasks, then numGroups, per source
+    testTwoWayAutoGroupingV0(emConfig);
+    testTwoWayAutoGroupingV1(emConfig);
+  }
+
+  private void testTwoWayAutoGroupingV0(CartesianProductEdgeManagerConfig config) throws Exception {
+    when(mockContext.getSourceVertexName()).thenReturn("v0");
+    when(mockContext.getSourceVertexNumTasks()).thenReturn(20);
+    edgeManager.initialize(config); // v0: 20 tasks in 10 groups of 2 tasks
+
+    CompositeEventRouteMetadata compositeRoutingData = edgeManager.routeCompositeDataMovementEventToDestination(1, 1);
+    assertNull(compositeRoutingData); // src task 1 (group 0) is not read by dest task 1
+
+    compositeRoutingData = edgeManager.routeCompositeDataMovementEventToDestination(1, 0);
+    assertNotNull(compositeRoutingData);
+    assertEquals(1, compositeRoutingData.getCount());
+    assertEquals(1, compositeRoutingData.getTarget()); // src task 1 is 2nd in its v0 group
+    assertEquals(0, compositeRoutingData.getSource());
+
+    EventRouteMetadata routingData = edgeManager.routeInputSourceTaskFailedEventToDestination(2, 2);
+    assertNull(routingData); // src task 2 is in group 1, which dest task 2 does not read
+
+    routingData = edgeManager.routeInputSourceTaskFailedEventToDestination(2, 1);
+    assertNotNull(routingData);
+    assertEquals(1, routingData.getNumEvents());
+    assertArrayEquals(new int[]{0}, routingData.getTargetIndices()); // 1st task of group 1
+
+    assertEquals(7, edgeManager.routeInputErrorEventToSource(3, 1)); // group 3 = tasks 6-7
+
+    assertEquals(2, edgeManager.getNumDestinationTaskPhysicalInputs(4)); // 2 tasks per group
+    assertEquals(1, edgeManager.getNumSourceTaskPhysicalOutputs(5));
+    assertEquals(1, edgeManager.getNumDestinationConsumerTasks(6)); // v1 has a single group
+  }
+
+  private void testTwoWayAutoGroupingV1(CartesianProductEdgeManagerConfig config) throws Exception {
+    when(mockContext.getSourceVertexName()).thenReturn("v1");
+    when(mockContext.getSourceVertexNumTasks()).thenReturn(10);
+    edgeManager.initialize(config); // v1: 10 tasks in a single group
+
+    CompositeEventRouteMetadata compositeRoutingData = edgeManager.routeCompositeDataMovementEventToDestination(1, 1);
+    assertNotNull(compositeRoutingData);
+    assertEquals(1, compositeRoutingData.getCount());
+    assertEquals(1, compositeRoutingData.getTarget()); // src task 1 is 2nd in the group
+    assertEquals(0, compositeRoutingData.getSource());
+
+    EventRouteMetadata routingData = edgeManager.routeInputSourceTaskFailedEventToDestination(2, 3);
+    assertNotNull(routingData);
+    assertEquals(1, routingData.getNumEvents());
+    assertArrayEquals(new int[]{2}, routingData.getTargetIndices()); // src task 2 -> index 2
+
+    assertEquals(5, edgeManager.routeInputErrorEventToSource(4, 5)); // input index 5 = task 5
+
+    assertEquals(10, edgeManager.getNumDestinationTaskPhysicalInputs(6)); // whole v1 group
+    assertEquals(1, edgeManager.getNumSourceTaskPhysicalOutputs(7));
+    assertEquals(10, edgeManager.getNumDestinationConsumerTasks(8)); // one per v0 group
+  }
 }

http://git-wip-us.apache.org/repos/asf/tez/blob/506c9bc3/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManagerConfig.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManagerConfig.java b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManagerConfig.java
new file mode 100644
index 0000000..bf369d9
--- /dev/null
+++ b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManagerConfig.java
@@ -0,0 +1,53 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.tez.runtime.library.cartesianproduct;
+
+import org.apache.tez.dag.api.TezConfiguration;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+
+public class TestCartesianProductVertexManagerConfig {
+  @Test(timeout = 5000)
+  public void testAutoGroupingConfig() throws IOException { // grouping settings via user payload
+    List<String> sourceVertices = new ArrayList<>();
+    sourceVertices.add("v0");
+    sourceVertices.add("v1");
+    CartesianProductConfig config = new CartesianProductConfig(sourceVertices);
+    TezConfiguration conf = new TezConfiguration();
+
+    // auto grouping not configured: the vertex manager config falls back to the defaults
+    CartesianProductVertexManagerConfig vmConf =
+      CartesianProductVertexManagerConfig.fromUserPayload(config.toUserPayload(conf));
+    assertEquals(CartesianProductVertexManager.TEZ_CARTESIAN_PRODUCT_ENABLE_AUTO_GROUPING_DEFAULT,
+      vmConf.isEnableAutoGrouping());
+    assertEquals(CartesianProductVertexManager.TEZ_CARTESIAN_PRODUCT_DESIRED_BYTES_PER_GROUP_DEFAULT,
+      vmConf.getDesiredBytesPerGroup());
+
+    // auto grouping configured: the configured values are carried through the payload
+    conf.setBoolean(CartesianProductVertexManager.TEZ_CARTESIAN_PRODUCT_ENABLE_AUTO_GROUPING, true);
+    conf.setLong(CartesianProductVertexManager.TEZ_CARTESIAN_PRODUCT_DESIRED_BYTES_PER_GROUP, 1000);
+    vmConf = CartesianProductVertexManagerConfig.fromUserPayload(config.toUserPayload(conf));
+    assertEquals(true, vmConf.isEnableAutoGrouping());
+    assertEquals(1000, vmConf.getDesiredBytesPerGroup());
+  }
+}

http://git-wip-us.apache.org/repos/asf/tez/blob/506c9bc3/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManagerPartitioned.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManagerPartitioned.java b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManagerPartitioned.java
index 99067f1..36c0325 100644
--- a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManagerPartitioned.java
+++ b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManagerPartitioned.java
@@ -69,13 +69,16 @@ public class TestCartesianProductVertexManagerPartitioned {
     setupWithConfig(
       new CartesianProductVertexManagerConfig(true, new String[]{"v0","v1"}, new int[] {2, 2},
         CartesianProductVertexManager.TEZ_CARTESIAN_PRODUCT_SLOW_START_MIN_FRACTION_DEFAULT,
-        CartesianProductVertexManager.TEZ_CARTESIAN_PRODUCT_SLOW_START_MAX_FRACTION_DEFAULT, null));
+        CartesianProductVertexManager.TEZ_CARTESIAN_PRODUCT_SLOW_START_MAX_FRACTION_DEFAULT,
+        false, 0, null));
   }
 
   private void setupWithConfig(CartesianProductVertexManagerConfig config)
     throws TezReflectionException {
     MockitoAnnotations.initMocks(this);
     context = mock(VertexManagerPluginContext.class);
+    when(context.getVertexName()).thenReturn("cp");
+    when(context.getVertexNumTasks("cp")).thenReturn(-1);
     vertexManager = new CartesianProductVertexManagerPartitioned(context);
     Map<String, EdgeProperty> edgePropertyMap = new HashMap<>();
     edgePropertyMap.put("v0", EdgeProperty.create(EdgeManagerPluginDescriptor.create(
@@ -116,10 +119,10 @@ public class TestCartesianProductVertexManagerPartitioned {
   public void testReconfigureVertex() throws Exception {
     testReconfigureVertexHelper(
       new CartesianProductVertexManagerConfig(true, new String[]{"v0", "v1"}, new int[] {5, 5}, 0,
-        0, new CartesianProductFilterDescriptor(TestFilter.class.getName())), 10);
+        0, false, 0, new CartesianProductFilterDescriptor(TestFilter.class.getName())), 10);
     testReconfigureVertexHelper(
       new CartesianProductVertexManagerConfig(true, new String[]{"v0", "v1"}, new int[] {5, 5}, 0,
-        0, null), 25);
+        0, false, 0, null), 25);
   }
 
   @Test(timeout = 5000)

http://git-wip-us.apache.org/repos/asf/tez/blob/506c9bc3/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManagerUnpartitioned.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManagerUnpartitioned.java b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManagerUnpartitioned.java
index dfe2830..31a3941 100644
--- a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManagerUnpartitioned.java
+++ b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManagerUnpartitioned.java
@@ -31,6 +31,8 @@ import org.apache.tez.dag.records.TezTaskAttemptID;
 import org.apache.tez.dag.records.TezTaskID;
 import org.apache.tez.dag.records.TezVertexID;
 import org.apache.tez.runtime.api.TaskAttemptIdentifier;
+import org.apache.tez.runtime.api.events.VertexManagerEvent;
+import org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads.VertexManagerEventPayloadProto;
 import org.junit.Before;
 import org.junit.Test;
 import org.mockito.ArgumentCaptor;
@@ -39,6 +41,7 @@ import org.mockito.Matchers;
 import org.mockito.MockitoAnnotations;
 
 import java.util.ArrayList;
+import java.util.Formatter;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -86,7 +89,8 @@ public class TestCartesianProductVertexManagerUnpartitioned {
     when(context.getVertexNumTasks(eq("v2"))).thenReturn(5);
 
     CartesianProductVertexManagerConfig config =
-      new CartesianProductVertexManagerConfig(false, new String[]{"v0","v1"}, null, 0, 0, null);
+      new CartesianProductVertexManagerConfig(
+        false, new String[]{"v0","v1"}, null, 0, 0, false, 0, null);
     vertexManager.initialize(config);
     allCompletions = new ArrayList<>();
     allCompletions.add(new TaskAttemptIdentifierImpl("dag", "v0",
@@ -126,6 +130,7 @@ public class TestCartesianProductVertexManagerUnpartitioned {
       CartesianProductEdgeManagerConfig newConfig =
         CartesianProductEdgeManagerConfig.fromUserPayload(payload);
       assertArrayEquals(new int[]{2,3}, newConfig.getNumTasks());
+      assertArrayEquals(new int[]{2,3}, newConfig.getNumGroups());
     }
   }
 
@@ -204,7 +209,8 @@ public class TestCartesianProductVertexManagerUnpartitioned {
     when(context.getVertexNumTasks(eq("v1"))).thenReturn(0);
 
     CartesianProductVertexManagerConfig config =
-      new CartesianProductVertexManagerConfig(false, new String[]{"v0","v1"}, null, 0, 0, null);
+      new CartesianProductVertexManagerConfig(
+        false, new String[]{"v0","v1"}, null, 0, 0, false, 0, null);
     Map<String, EdgeProperty> edgePropertyMap = new HashMap<>();
     edgePropertyMap.put("v0", EdgeProperty.create(EdgeManagerPluginDescriptor.create(
       CartesianProductEdgeManager.class.getName()), null, null, null, null));
@@ -213,10 +219,143 @@ public class TestCartesianProductVertexManagerUnpartitioned {
     when(context.getInputVertexEdgeProperties()).thenReturn(edgePropertyMap);
 
     vertexManager.initialize(config);
+    allCompletions = new ArrayList<>();
+    allCompletions.add(new TaskAttemptIdentifierImpl("dag", "v0",
+      TezTaskAttemptID.getInstance(TezTaskID.getInstance(TezVertexID.getInstance(
+        TezDAGID.getInstance("0", 0, 0), 0), 0), 0)));
+    allCompletions.add(new TaskAttemptIdentifierImpl("dag", "v0",
+      TezTaskAttemptID.getInstance(TezTaskID.getInstance(TezVertexID.getInstance(
+        TezDAGID.getInstance("0", 0, 0), 0), 1), 0)));
+
+    vertexManager.onVertexStateUpdated(new VertexStateUpdate("v0", VertexState.CONFIGURED));
+    vertexManager.onVertexStateUpdated(new VertexStateUpdate("v1", VertexState.CONFIGURED));
+    vertexManager.onVertexStarted(new ArrayList<TaskAttemptIdentifier>());
+    vertexManager.onSourceTaskCompleted(allCompletions.get(0));
+    vertexManager.onSourceTaskCompleted(allCompletions.get(1));
+  }
+
+  @Test(timeout = 5000)
+  public void testAutoGrouping() throws Exception {
+    testAutoGroupingHelper(/* enableAutoGrouping= */ false);
+    testAutoGroupingHelper(/* enableAutoGrouping= */ true);
+  }
+
+  /**
+   * Drives the unpartitioned cartesian product vertex manager through the auto-grouping flow.
+   *
+   * <p>v0 has 20 tasks reporting 500 bytes each and v1 has 10 tasks reporting 10 bytes each,
+   * with 1000 desired bytes per group. The expected group counts match 20 * 500 / 1000 = 10
+   * for v0, and a single group for v1, whose 100 bytes in total fit in one group.
+   *
+   * @param enableAutoGrouping when false, the vertex must be reconfigured to the plain
+   *     cartesian product parallelism without waiting for any VertexManagerEvent
+   */
+  private void testAutoGroupingHelper(boolean enableAutoGrouping) throws Exception {
+    int numTaskV0 = 20;
+    int numTaskV1 = 10;
+    long desiredBytesPerGroup = 1000;
+    long outputBytesPerTaskV0 = 500;
+    long outputBytesPerTaskV1 = 10;
+    int expectedNumGroupV0 = 10;
+    int expectedNumGroupV1 = 1;
+    ArgumentCaptor<Integer> parallelismCaptor = ArgumentCaptor.forClass(Integer.class);
+    CartesianProductVertexManagerConfig config = new CartesianProductVertexManagerConfig(
+      false, new String[]{"v0","v1"}, null, 0, 0, enableAutoGrouping, desiredBytesPerGroup, null);
+    Map<String, EdgeProperty> edgePropertyMap = new HashMap<>();
+    EdgeProperty edgeProperty = EdgeProperty.create(EdgeManagerPluginDescriptor.create(
+      CartesianProductEdgeManager.class.getName()), null, null, null, null);
+    edgePropertyMap.put("v0", edgeProperty);
+    edgePropertyMap.put("v1", edgeProperty);
+    edgePropertyMap.put("v2", EdgeProperty.create(BROADCAST, null, null, null, null));
+
+    // Fresh mock so that verify() only sees interactions made during this invocation.
+    context = mock(VertexManagerPluginContext.class);
+    vertexManager = new CartesianProductVertexManagerUnpartitioned(context);
+    when(context.getVertexNumTasks(eq("v0"))).thenReturn(numTaskV0);
+    when(context.getVertexNumTasks(eq("v1"))).thenReturn(numTaskV1);
+    when(context.getInputVertexEdgeProperties()).thenReturn(edgePropertyMap);
+
+    vertexManager.initialize(config);
     vertexManager.onVertexStateUpdated(new VertexStateUpdate("v0", VertexState.CONFIGURED));
     vertexManager.onVertexStateUpdated(new VertexStateUpdate("v1", VertexState.CONFIGURED));
+    vertexManager.onVertexStateUpdated(new VertexStateUpdate("v2", VertexState.RUNNING));
+    if (!enableAutoGrouping) {
+      // auto grouping disabled: reconfigure immediately to the ungrouped parallelism
+      verify(context, times(1)).reconfigureVertex(parallelismCaptor.capture(),
+        isNull(VertexLocationHint.class), edgePropertiesCaptor.capture());
+      assertEquals(numTaskV0 * numTaskV1, parallelismCaptor.getValue().intValue());
+      return;
+    }
+
+    // no output size reported yet, so there must be no auto grouping
+    verify(context, never()).reconfigureVertex(anyInt(), any(VertexLocationHint.class),
+      anyMapOf(String.class, EdgeProperty.class));
+
+    // v0 reaches the size threshold but v1 has reported nothing yet, so no auto grouping
+    VertexManagerEventPayloadProto.Builder builder = VertexManagerEventPayloadProto.newBuilder();
+    builder.setOutputSize(outputBytesPerTaskV0);
+    VertexManagerEventPayloadProto proto = builder.build();
+    VertexManagerEvent vmEvent =
+      VertexManagerEvent.create("cp vertex", proto.toByteString().asReadOnlyByteBuffer());
+
+    // String.format builds a fresh id per event; a reused java.util.Formatter would keep
+    // appending, so toString() would hand fromString() every previous id concatenated.
+    for (int i = 0; i < desiredBytesPerGroup / outputBytesPerTaskV0; i++) {
+      vmEvent.setProducerAttemptIdentifier(
+        new TaskAttemptIdentifierImpl("dag", "v0", TezTaskAttemptID.fromString(
+          String.format("attempt_1441301219877_0109_1_00_%06d_0", i))));
+      vertexManager.onVertexManagerEventReceived(vmEvent);
+    }
+    verify(context, never()).reconfigureVertex(anyInt(), any(VertexLocationHint.class),
+      anyMapOf(String.class, EdgeProperty.class));
+
+    // vmEvent from broadcast vertex shouldn't matter
+    vmEvent.setProducerAttemptIdentifier(new TaskAttemptIdentifierImpl("dag", "v2",
+        TezTaskAttemptID.fromString("attempt_1441301219877_0109_1_00_000000_0")));
+    vertexManager.onVertexManagerEventReceived(vmEvent);
+
+    // v1 reports for all its tasks without reaching the threshold; auto grouping happens anyway
+    proto = builder.setOutputSize(outputBytesPerTaskV1).build();
+    vmEvent = VertexManagerEvent.create("cp vertex", proto.toByteString().asReadOnlyByteBuffer());
+    for (int i = 0; i < numTaskV1; i++) {
+      verify(context, never()).reconfigureVertex(anyInt(), any(VertexLocationHint.class),
+        anyMapOf(String.class, EdgeProperty.class));
+      vmEvent.setProducerAttemptIdentifier(
+        new TaskAttemptIdentifierImpl("dag", "v1", TezTaskAttemptID.fromString(
+          String.format("attempt_1441301219877_0109_1_01_%06d_0", i))));
+      vertexManager.onVertexManagerEventReceived(vmEvent);
+    }
+    verify(context, times(1)).reconfigureVertex(parallelismCaptor.capture(),
+      isNull(VertexLocationHint.class), edgePropertiesCaptor.capture());
+    assertEquals(expectedNumGroupV0 * expectedNumGroupV1, parallelismCaptor.getValue().intValue());
+    // each reconfigured edge keeps the original task counts and carries the new group counts
+    for (EdgeProperty property : edgePropertiesCaptor.getValue().values()) {
+      CartesianProductEdgeManagerConfig emConfig =
+        CartesianProductEdgeManagerConfig.fromUserPayload(
+          property.getEdgeManagerDescriptor().getUserPayload());
+      assertArrayEquals(new int[] {numTaskV0, numTaskV1}, emConfig.getNumTasks());
+      assertArrayEquals(new int[] {expectedNumGroupV0, expectedNumGroupV1}, emConfig.getNumGroups());
+    }
+
     vertexManager.onVertexStarted(null);
+    // v0 task 0 finishes alone, nothing should be scheduled yet
     vertexManager.onSourceTaskCompleted(allCompletions.get(0));
+    verify(context, never()).scheduleTasks(Matchers.<List<ScheduleTaskRequest>>any());
+
+    // all v1 tasks finish, still nothing should be scheduled
+    for (int i = 0; i < numTaskV1; i++) {
+      vertexManager.onSourceTaskCompleted(new TaskAttemptIdentifierImpl("dag", "v1",
+        TezTaskAttemptID.getInstance(TezTaskID.getInstance(TezVertexID.getInstance(
+          TezDAGID.getInstance("0", 0, 0), 1), i), 0)));
+      verify(context, never()).scheduleTasks(Matchers.<List<ScheduleTaskRequest>>any());
+    }
+
+    // v0 task 1 finishes, destination task 0 should now be scheduled
     vertexManager.onSourceTaskCompleted(allCompletions.get(1));
+    verify(context, times(1)).scheduleTasks(scheduleTaskRequestCaptor.capture());
+    List<ScheduleTaskRequest> requests = scheduleTaskRequestCaptor.getValue();
+    assertNotNull(requests);
+    assertEquals(1, requests.size());
+    assertEquals(0, requests.get(0).getTaskIndex());
   }
 }
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/tez/blob/506c9bc3/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestGrouper.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestGrouper.java b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestGrouper.java
new file mode 100644
index 0000000..481bd7e
--- /dev/null
+++ b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestGrouper.java
@@ -0,0 +1,80 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.tez.runtime.library.cartesianproduct;
+
+import org.apache.tez.runtime.library.utils.Grouper;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+public class TestGrouper {
+  private final Grouper taskGrouper = new Grouper();
+
+  @Test(timeout = 5000)
+  public void testEvenlyGrouping() {
+    taskGrouper.init(4, 2); // 4 tasks, 2 groups: {0, 1} and {2, 3}
+    assertEquals(0, taskGrouper.getFirstTaskInGroup(0));
+    assertEquals(1, taskGrouper.getLastTaskInGroup(0));
+    assertEquals(2, taskGrouper.getNumTasksInGroup(0));
+    assertEquals(2, taskGrouper.getFirstTaskInGroup(1));
+    assertEquals(3, taskGrouper.getLastTaskInGroup(1));
+    assertEquals(2, taskGrouper.getNumTasksInGroup(1));
+    assertEquals(0, taskGrouper.getGroupId(1));
+    assertEquals(1, taskGrouper.getGroupId(2));
+    assertTrue(taskGrouper.isInGroup(2, 1));
+    assertFalse(taskGrouper.isInGroup(2, 0));
+  }
+
+  @Test(timeout = 5000)
+  public void testUnevenlyGrouping() {
+    taskGrouper.init(5, 2); // 5 tasks, 2 groups: {0, 1} and {2, 3, 4}
+    assertEquals(0, taskGrouper.getFirstTaskInGroup(0));
+    assertEquals(1, taskGrouper.getLastTaskInGroup(0));
+    assertEquals(2, taskGrouper.getNumTasksInGroup(0));
+    assertEquals(2, taskGrouper.getFirstTaskInGroup(1));
+    assertEquals(4, taskGrouper.getLastTaskInGroup(1));
+    assertEquals(3, taskGrouper.getNumTasksInGroup(1));
+    assertEquals(0, taskGrouper.getGroupId(1));
+    assertEquals(1, taskGrouper.getGroupId(3));
+    assertTrue(taskGrouper.isInGroup(3, 1));
+    assertFalse(taskGrouper.isInGroup(3, 0));
+  }
+
+  @Test(timeout = 5000)
+  public void testSingleGroup() {
+    taskGrouper.init(4, 1); // 4 tasks, 1 group: {0, 1, 2, 3}
+    assertEquals(0, taskGrouper.getFirstTaskInGroup(0));
+    assertEquals(3, taskGrouper.getLastTaskInGroup(0));
+    assertEquals(4, taskGrouper.getNumTasksInGroup(0));
+    assertEquals(0, taskGrouper.getGroupId(0));
+    assertEquals(0, taskGrouper.getGroupId(3));
+    assertTrue(taskGrouper.isInGroup(3, 0));
+  }
+
+  @Test(timeout = 5000)
+  public void testNoGrouping() {
+    taskGrouper.init(2, 2); // 2 tasks, 2 groups: one task per group
+    assertEquals(0, taskGrouper.getFirstTaskInGroup(0));
+    assertEquals(0, taskGrouper.getLastTaskInGroup(0));
+    assertEquals(1, taskGrouper.getNumTasksInGroup(0));
+    assertEquals(0, taskGrouper.getGroupId(0));
+    assertTrue(taskGrouper.isInGroup(0, 0));
+  }
+}


Mime
View raw message