tez-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From min...@apache.org
Subject [1/2] tez git commit: TEZ-3230. Implement vertex manager and edge manager of cartesian product edge. (Zhiyuan Yang via mingma)
Date Tue, 06 Sep 2016 17:51:14 GMT
Repository: tez
Updated Branches:
  refs/heads/master af8246931 -> 1a068b239


http://git-wip-us.apache.org/repos/asf/tez/blob/1a068b23/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManagerUnpartitioned.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManagerUnpartitioned.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManagerUnpartitioned.java
new file mode 100644
index 0000000..af7d15e
--- /dev/null
+++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManagerUnpartitioned.java
@@ -0,0 +1,178 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.tez.runtime.library.cartesianproduct;
+
+import com.google.common.base.Preconditions;
+import com.google.common.primitives.Ints;
+import org.apache.tez.dag.api.EdgeManagerPluginDescriptor;
+import org.apache.tez.dag.api.EdgeProperty;
+import org.apache.tez.dag.api.UserPayload;
+import org.apache.tez.dag.api.VertexManagerPluginContext;
+import org.apache.tez.dag.api.VertexManagerPluginContext.ScheduleTaskRequest;
+import org.apache.tez.dag.api.event.VertexState;
+import org.apache.tez.dag.api.event.VertexStateUpdate;
+import org.apache.tez.runtime.api.TaskAttemptIdentifier;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.BitSet;
+import java.util.EnumSet;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Queue;
+
+import static org.apache.tez.runtime.library.cartesianproduct.CartesianProductUserPayload.CartesianProductConfigProto;
+
+/**
+ * Vertex manager for the unpartitioned cartesian product case. It waits until every source
+ * vertex has reported CONFIGURED, then sets this vertex's parallelism to the product of the
+ * source task counts and installs an unpartitioned CartesianProductConfigProto payload (source
+ * vertices plus per-source task counts) on every input edge. A destination task is scheduled
+ * once every source task in its combination has completed.
+ */
+class CartesianProductVertexManagerUnpartitioned extends CartesianProductVertexManagerReal {
+  // source vertex names; a vertex's position in this list is its dimension index for numTasks
+  // and CartesianProductCombination
+  List<String> sourceVertices;
+  // product of all source task counts, accumulated in reconfigureVertex()
+  private int parallelism = 1;
+  private boolean vertexStarted = false;
+  private boolean vertexReconfigured = false;
+  // number of source vertices that have reported CONFIGURED so far
+  private int numSourceVertexConfigured = 0;
+  // task count of each source vertex, indexed like sourceVertices
+  private int[] numTasks;
+  // source task completions received before reconfiguration; drained once the vertex is both
+  // started and reconfigured
+  private Queue<TaskAttemptIdentifier> pendingCompletedSrcTask = new LinkedList<>();
+  // per source vertex, the task ids already seen as completed
+  private Map<String, BitSet> sourceTaskCompleted = new HashMap<>();
+  // destination task ids already passed to scheduleTasks, so none is scheduled twice
+  private BitSet scheduledTasks = new BitSet();
+  private CartesianProductConfig config;
+  // number of source vertices that have at least one completed task
+  private int numSrcHasCompletedTask = 0;
+
+  public CartesianProductVertexManagerUnpartitioned(VertexManagerPluginContext context) {
+    super(context);
+  }
+
+  /**
+   * Records the source vertices, registers for their CONFIGURED state updates and announces
+   * via vertexReconfigurationPlanned() that this vertex will be reconfigured.
+   */
+  @Override
+  public void initialize(CartesianProductVertexManagerConfig config) throws Exception {
+    sourceVertices = config.getSourceVertices();
+    numTasks = new int[sourceVertices.size()];
+    for (String vertex : sourceVertices) {
+      sourceTaskCompleted.put(vertex, new BitSet());
+    }
+    for (String vertex : sourceVertices) {
+      getContext().registerForVertexStateUpdates(vertex, EnumSet.of(VertexState.CONFIGURED));
+    }
+    this.config = config;
+    getContext().vertexReconfigurationPlanned();
+  }
+
+  /**
+   * Sets parallelism to the product of the source task counts, puts one shared unpartitioned
+   * CartesianProductConfigProto payload on every input edge's edge manager descriptor, and
+   * completes the reconfiguration.
+   */
+  private void reconfigureVertex() throws IOException {
+    // multiplies into the field (initialized to 1), so this relies on running only once
+    for (int numTask : numTasks) {
+      parallelism *= numTask;
+    }
+
+    UserPayload payload = null;
+    // NOTE(review): every input edge gets this payload; presumably all inputs of this vertex are
+    // cartesian product edges with a non-null edge manager descriptor — confirm
+    Map<String, EdgeProperty> edgeProperties = getContext().getInputVertexEdgeProperties();
+    for (EdgeProperty edgeProperty : edgeProperties.values()) {
+      EdgeManagerPluginDescriptor descriptor = edgeProperty.getEdgeManagerDescriptor();
+      // build the payload lazily on the first edge and share it across all edges
+      if (payload == null) {
+        CartesianProductConfigProto.Builder builder = CartesianProductConfigProto.newBuilder();
+        builder.setIsPartitioned(false).addAllNumTasks(Ints.asList(numTasks))
+          .addAllSourceVertices(config.getSourceVertices());
+        payload = UserPayload.create(ByteBuffer.wrap(builder.build().toByteArray()));
+      }
+      descriptor.setUserPayload(payload);
+    }
+    getContext().reconfigureVertex(parallelism, null, edgeProperties);
+    vertexReconfigured = true;
+    getContext().doneReconfiguringVertex();
+  }
+
+  /**
+   * Marks the vertex as started. Completions delivered here are handled now if the vertex is
+   * already reconfigured; otherwise they are queued until reconfiguration.
+   */
+  @Override
+  public synchronized void onVertexStarted(List<TaskAttemptIdentifier> completions)
+    throws Exception {
+    vertexStarted = true;
+    // if vertex is already reconfigured, we can handle pending completions immediately
+    // otherwise we have to wait until vertex is reconfigured
+    if (vertexReconfigured) {
+      Preconditions.checkArgument(pendingCompletedSrcTask.size() == 0,
+        "Unexpected pending source completion on vertex start after vertex reconfiguration");
+      for (TaskAttemptIdentifier taId : completions) {
+        handleCompletedSrcTask(taId);
+      }
+    } else {
+      pendingCompletedSrcTask.addAll(completions);
+    }
+  }
+
+  /**
+   * Handles a CONFIGURED update from a source vertex (the only state registered for): records
+   * its task count, and when the last source vertex reports, reconfigures this vertex and, if it
+   * has already started, drains the queued completions.
+   */
+  @Override
+  public synchronized void onVertexStateUpdated(VertexStateUpdate stateUpdate) throws IOException {
+    Preconditions.checkArgument(stateUpdate.getVertexState() == VertexState.CONFIGURED);
+    String vertex = stateUpdate.getVertexName();
+    // NOTE(review): indexOf returns -1 for a vertex not in sourceVertices, and a repeated
+    // CONFIGURED update for the same vertex would be counted twice; this relies on the framework
+    // sending exactly one update per registered vertex — confirm
+    numTasks[sourceVertices.indexOf(vertex)] = getContext().getVertexNumTasks(vertex);
+    // reconfigure vertex when all source vertices are CONFIGURED
+    if (++numSourceVertexConfigured == sourceVertices.size()) {
+      reconfigureVertex();
+      // handle pending source completions when vertex is started and reconfigured
+      if (vertexStarted) {
+        while (!pendingCompletedSrcTask.isEmpty()) {
+          handleCompletedSrcTask(pendingCompletedSrcTask.poll());
+        }
+      }
+    }
+  }
+
+  /**
+   * Queues the completion while some source vertex is not yet configured; otherwise handles it
+   * immediately.
+   */
+  @Override
+  public synchronized void onSourceTaskCompleted(TaskAttemptIdentifier attempt) throws Exception {
+    // NOTE(review): assumes this callback only arrives after onVertexStarted. A completion queued
+    // here before start, followed by reconfiguration before start, would still be pending when
+    // onVertexStarted runs and trip its precondition — confirm against framework guarantees
+    if (numSourceVertexConfigured < sourceVertices.size()) {
+      pendingCompletedSrcTask.add(attempt);
+      return;
+    }
+    Preconditions.checkArgument(pendingCompletedSrcTask.size() == 0,
+      "Unexpected pending src completion on source task completed after vertex reconfiguration");
+    handleCompletedSrcTask(attempt);
+  }
+
+  /**
+   * Records a completed source task (duplicates are ignored). Once every source vertex has at
+   * least one completed task, walks every destination combination that contains this source
+   * task. Combinations whose source tasks have all completed, and that are not yet scheduled,
+   * are scheduled in a single scheduleTasks call.
+   */
+  private void handleCompletedSrcTask(TaskAttemptIdentifier attempt) {
+    int taskId = attempt.getTaskIdentifier().getIdentifier();
+    String vertex = attempt.getTaskIdentifier().getVertexIdentifier().getName();
+    if (sourceTaskCompleted.get(vertex).get(taskId)) {
+      return;
+    }
+
+    // first completed task of this source vertex
+    if (sourceTaskCompleted.get(vertex).isEmpty()) {
+      numSrcHasCompletedTask++;
+    }
+    sourceTaskCompleted.get(vertex).set(taskId);
+    // no combination can be complete until every source vertex has a completed task
+    if (numSrcHasCompletedTask != sourceVertices.size()) {
+      return;
+    }
+
+    List<ScheduleTaskRequest> requests = new ArrayList<>();
+    // iterate all combinations in which this vertex's dimension is fixed to taskId
+    CartesianProductCombination combination = new CartesianProductCombination(numTasks, sourceVertices.indexOf(vertex));
+    combination.firstTaskWithFixedPartition(taskId);
+    do {
+      List<Integer> list = combination.getCombination();
+      boolean readyToSchedule = true;
+      for (int i = 0; i < list.size(); i++) {
+        if (!sourceTaskCompleted.get(sourceVertices.get(i)).get(list.get(i))) {
+          readyToSchedule = false;
+          break;
+        }
+      }
+      if (readyToSchedule && !scheduledTasks.get(combination.getTaskId())) {
+        requests.add(ScheduleTaskRequest.create(combination.getTaskId(), null));
+        scheduledTasks.set(combination.getTaskId());
+      }
+    } while (combination.nextTaskWithFixedPartition());
+    if (!requests.isEmpty()) {
+      getContext().scheduleTasks(requests);
+    }
+  }
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/tez/blob/1a068b23/tez-runtime-library/src/main/proto/CartesianProductPayload.proto
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/proto/CartesianProductPayload.proto b/tez-runtime-library/src/main/proto/CartesianProductPayload.proto
new file mode 100644
index 0000000..39ba82c
--- /dev/null
+++ b/tez-runtime-library/src/main/proto/CartesianProductPayload.proto
@@ -0,0 +1,31 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Generated Java classes go to this package, nested inside the outer class
+// CartesianProductUserPayload.
+option java_package = "org.apache.tez.runtime.library.cartesianproduct";
+option java_outer_classname = "CartesianProductUserPayload";
+
+// Cartesian product configuration, serialized into a Tez UserPayload for the vertex and
+// edge managers. Field numbers are part of the wire format: never renumber or reuse them.
+message CartesianProductConfigProto {
+    // true selects the partitioned implementation, false the unpartitioned one
+    required bool isPartitioned = 1;
+    // source vertices taking part in the product, in dimension order
+    repeated string sourceVertices = 2;
+    // partitioned case: number of partitions per source vertex, aligned with sourceVertices
+    repeated int32 numPartitions = 3;
+    // optional combination filter: its class name and serialized user payload
+    optional string filterClassName = 4;
+    optional bytes filterUserPayload = 5;
+    // NOTE(review): semantics not visible here; presumably slow-start fractions — confirm
+    optional float minFraction = 6;
+    optional float maxFraction = 7;
+    // unpartitioned case: number of tasks per source vertex, aligned with sourceVertices
+    repeated int32 numTasks = 8;
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/tez/blob/1a068b23/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductCombination.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductCombination.java b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductCombination.java
new file mode 100644
index 0000000..0d6a928
--- /dev/null
+++ b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductCombination.java
@@ -0,0 +1,110 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.tez.runtime.library.cartesianproduct;
+
+import com.google.common.primitives.Ints;
+import org.junit.Test;
+
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Unit tests for {@link CartesianProductCombination}. They cover enumeration with one dimension
+ * fixed, full enumeration, and reconstruction from a task id. Task ids are row-major, e.g. for
+ * numTasks {2,3} the combination {a,b} has task id a*3+b.
+ */
+public class TestCartesianProductCombination {
+  /** Asserts that {@code combination} currently holds {@code result} and has task id {@code taskId}. */
+  private void verifyCombination(CartesianProductCombination combination, int[] result, int taskId) {
+    assertArrayEquals(result, Ints.toArray(combination.getCombination()));
+    assertEquals(taskId, combination.getTaskId());
+  }
+
+  /** numTasks {2,3} with dimension 0 fixed to 1: walks the three values of dimension 1. */
+  private void testCombinationTwoWayVertex0() {
+    CartesianProductCombination combination = new CartesianProductCombination(new int[]{2,3}, 0);
+
+    combination.firstTaskWithFixedPartition(1);
+    verifyCombination(combination, new int[]{1,0}, 3);
+    assertTrue(combination.nextTaskWithFixedPartition());
+    verifyCombination(combination, new int[]{1,1}, 4);
+    assertTrue(combination.nextTaskWithFixedPartition());
+    verifyCombination(combination, new int[]{1,2}, 5);
+    assertFalse(combination.nextTaskWithFixedPartition());
+  }
+
+  /** numTasks {2,3} with dimension 1 fixed to 1: walks the two values of dimension 0. */
+  private void testCombinationTwoWayVertex1() {
+    CartesianProductCombination combination = new CartesianProductCombination(new int[]{2,3}, 1);
+
+    combination.firstTaskWithFixedPartition(1);
+    verifyCombination(combination, new int[]{0,1}, 1);
+    assertTrue(combination.nextTaskWithFixedPartition());
+    verifyCombination(combination, new int[]{1,1}, 4);
+
+    assertFalse(combination.nextTaskWithFixedPartition());
+  }
+
+  /** numTasks {2,2,2} with the middle dimension fixed to 1: four combinations. */
+  private void testCombinationThreeWay() {
+    CartesianProductCombination combination = new CartesianProductCombination(new int[]{2,2,2}, 1);
+
+    combination.firstTaskWithFixedPartition(1);
+    verifyCombination(combination, new int[]{0,1,0}, 2);
+    assertTrue(combination.nextTaskWithFixedPartition());
+    verifyCombination(combination, new int[]{0,1,1}, 3);
+    assertTrue(combination.nextTaskWithFixedPartition());
+    verifyCombination(combination, new int[]{1,1,0}, 6);
+    assertTrue(combination.nextTaskWithFixedPartition());
+    verifyCombination(combination, new int[]{1,1,1}, 7);
+    assertFalse(combination.nextTaskWithFixedPartition());
+  }
+
+  @Test(timeout = 5000)
+  public void testCombinationWithFixedPartition() {
+    // two way cartesian product
+    testCombinationTwoWayVertex0();
+    testCombinationTwoWayVertex1();
+
+    // three way cartesian product
+    testCombinationThreeWay();
+  }
+
+  @Test(timeout = 5000)
+  public void testCombination() {
+    CartesianProductCombination combination = new CartesianProductCombination(new int[]{2,3});
+    // the list is fetched once; the assertions below only hold if it tracks later
+    // firstTask()/nextTask() calls
+    List<Integer> list = combination.getCombination();
+    for (int i = 0; i < 2; i++) {
+      for (int j = 0; j < 3; j++) {
+        if (i == 0 && j == 0) {
+          combination.firstTask();
+        } else {
+          assertTrue(combination.nextTask());
+        }
+        // assertEquals (not assertTrue(==)) so a failure reports expected vs. actual
+        assertEquals(i, list.get(0).intValue());
+        assertEquals(j, list.get(1).intValue());
+      }
+    }
+    assertFalse(combination.nextTask());
+  }
+
+  @Test(timeout = 5000)
+  public void testFromTaskId() {
+    for (int i = 0; i < 6; i++) {
+      List<Integer> list = CartesianProductCombination.fromTaskId(new int[]{2,3}, i)
+                                                      .getCombination();
+      assertEquals(i / 3, list.get(0).intValue());
+      assertEquals(i % 3, list.get(1).intValue());
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/tez/blob/1a068b23/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductConfig.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductConfig.java b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductConfig.java
new file mode 100644
index 0000000..2de750f
--- /dev/null
+++ b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductConfig.java
@@ -0,0 +1,106 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.tez.runtime.library.cartesianproduct;
+
+import com.google.common.primitives.Ints;
+import org.apache.tez.dag.api.TezConfiguration;
+import org.apache.tez.dag.api.UserPayload;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+
+/**
+ * Round-trip serialization tests for {@link CartesianProductConfig} through
+ * toUserPayload/fromUserPayload, plus validation of numPartitions for unpartitioned configs.
+ */
+public class TestCartesianProductConfig {
+  private TezConfiguration conf = new TezConfiguration();
+
+  @Test(timeout = 5000)
+  public void testSerializationPartitioned() throws IOException {
+    Map<String, Integer> vertexPartitionMap = new HashMap<>();
+    vertexPartitionMap.put("v1", 2);
+    vertexPartitionMap.put("v2", 3);
+    vertexPartitionMap.put("v3", 4);
+    String filterClassName = "filter";
+    byte[] bytes = new byte[10];
+    (new Random()).nextBytes(bytes);
+    CartesianProductFilterDescriptor filterDescriptor =
+      new CartesianProductFilterDescriptor(filterClassName)
+        .setUserPayload(UserPayload.create(ByteBuffer.wrap(bytes)));
+    CartesianProductConfig config =
+      new CartesianProductConfig(vertexPartitionMap, filterDescriptor);
+    UserPayload payload = config.toUserPayload(conf);
+    CartesianProductConfig parsedConfig = CartesianProductConfig.fromUserPayload(payload);
+    assertConfigEquals(config, parsedConfig);
+  }
+
+  @Test(timeout = 5000)
+  public void testSerializationUnpartitioned() throws Exception {
+    List<String> sourceVertices = new ArrayList<>();
+    sourceVertices.add("v1");
+    sourceVertices.add("v2");
+    sourceVertices.add("v3");
+    CartesianProductConfig config =
+      new CartesianProductConfig(sourceVertices);
+    UserPayload payload = config.toUserPayload(conf);
+    CartesianProductConfig parsedConfig = CartesianProductConfig.fromUserPayload(payload);
+    assertConfigEquals(config, parsedConfig);
+
+    // unpartitioned config should have null in numPartitions fields
+    try {
+      config = new CartesianProductConfig(false, new int[]{}, new String[]{"v0","v1"},null);
+      config.checkNumPartitions();
+    } catch (Exception e) {
+      return;
+    }
+    throw new AssertionError(
+      "checkNumPartitions() should reject an unpartitioned config with non-null numPartitions");
+  }
+
+  /**
+   * Asserts that both configs have the same source vertices, numPartitions and filter
+   * descriptor (class name and payload bytes).
+   */
+  private void assertConfigEquals(CartesianProductConfig config1, CartesianProductConfig config2) {
+    assertArrayEquals(config1.getSourceVertices().toArray(new String[0]),
+      config2.getSourceVertices().toArray(new String[0]));
+    if (config1.getNumPartitions() == null) {
+      assertNull(config2.getNumPartitions());
+    } else {
+      assertArrayEquals(Ints.toArray(config1.getNumPartitions()),
+        Ints.toArray(config2.getNumPartitions()));
+    }
+    CartesianProductFilterDescriptor descriptor1 = config1.getFilterDescriptor();
+    // compare against the second config (previously config1 was compared with itself)
+    CartesianProductFilterDescriptor descriptor2 = config2.getFilterDescriptor();
+
+    if (descriptor1 == null) {
+      assertNull(descriptor2);
+      return;
+    }
+    assertNotNull(descriptor2);
+    assertEquals(descriptor1.getClassName(), descriptor2.getClassName());
+    UserPayload payload1 = descriptor1.getUserPayload();
+    UserPayload payload2 = descriptor2.getUserPayload();
+    // a payload present on only one side is a mismatch, not something to skip
+    if (payload1 == null) {
+      assertNull(payload2);
+    } else {
+      assertNotNull(payload2);
+      assertEquals(0, payload1.getPayload().compareTo(payload2.getPayload()));
+    }
+  }
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/tez/blob/1a068b23/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductEdgeManager.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductEdgeManager.java b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductEdgeManager.java
new file mode 100644
index 0000000..9581a6e
--- /dev/null
+++ b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductEdgeManager.java
@@ -0,0 +1,68 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.tez.runtime.library.cartesianproduct;
+
+import com.google.common.primitives.Ints;
+import org.apache.tez.dag.api.EdgeManagerPluginContext;
+import org.apache.tez.dag.api.UserPayload;
+import org.junit.Test;
+
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.apache.tez.runtime.library.cartesianproduct.CartesianProductUserPayload.*;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * Checks that {@link CartesianProductEdgeManager} delegates to the partitioned or unpartitioned
+ * implementation according to the isPartitioned flag in its payload.
+ */
+public class TestCartesianProductEdgeManager {
+  @Test(timeout = 5000)
+  public void testInitialize() throws Exception {
+    EdgeManagerPluginContext context = mock(EdgeManagerPluginContext.class);
+    when(context.getSourceVertexName()).thenReturn("v0");
+    CartesianProductEdgeManager edgeManager = new CartesianProductEdgeManager(context);
+    // shared by both payloads below (this list was previously built but never used)
+    List<String> sourceVertices = Arrays.asList("v0", "v1");
+
+    // partitioned case
+    CartesianProductConfigProto.Builder builder = CartesianProductConfigProto.newBuilder();
+    builder.setIsPartitioned(true)
+      .addAllSourceVertices(sourceVertices)
+      .addAllNumPartitions(Ints.asList(2, 3));
+    UserPayload payload = UserPayload.create(ByteBuffer.wrap(builder.build().toByteArray()));
+    when(context.getUserPayload()).thenReturn(payload);
+    edgeManager.initialize();
+    assertTrue(edgeManager.getEdgeManagerReal()
+      instanceof CartesianProductEdgeManagerPartitioned);
+
+    // unpartitioned case: the same edge manager is re-initialized with a new payload
+    builder.clear();
+    builder.setIsPartitioned(false)
+      .addAllSourceVertices(sourceVertices)
+      .addAllNumTasks(Ints.asList(2, 3));
+    payload = UserPayload.create(ByteBuffer.wrap(builder.build().toByteArray()));
+    when(context.getUserPayload()).thenReturn(payload);
+    when(context.getSourceVertexNumTasks()).thenReturn(2);
+    edgeManager.initialize();
+    assertTrue(edgeManager.getEdgeManagerReal()
+      instanceof CartesianProductEdgeManagerUnpartitioned);
+  }
+}

http://git-wip-us.apache.org/repos/asf/tez/blob/1a068b23/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductEdgeManagerPartitioned.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductEdgeManagerPartitioned.java b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductEdgeManagerPartitioned.java
new file mode 100644
index 0000000..2e8697d
--- /dev/null
+++ b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductEdgeManagerPartitioned.java
@@ -0,0 +1,284 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.tez.runtime.library.cartesianproduct;
+
+import org.apache.tez.dag.api.EdgeManagerPluginContext;
+import org.apache.tez.dag.api.EdgeManagerPluginOnDemand.EventRouteMetadata;
+import org.apache.tez.dag.api.UserPayload;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.ByteBuffer;
+import java.util.Map;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class TestCartesianProductEdgeManagerPartitioned {
+  private EdgeManagerPluginContext mockContext;
+  private CartesianProductEdgeManagerPartitioned edgeManager;
+
+  @Before
+  public void setup() {
+    mockContext = mock(EdgeManagerPluginContext.class);
+    edgeManager = new CartesianProductEdgeManagerPartitioned(mockContext);
+  }
+
+  /**
+   * Vertex v0 has 2 tasks which generate 3 partitions
+   * Vertex v1 has 3 tasks which generate 4 partitions
+   */
+  @Test(timeout = 5000)
+  public void testTwoWay() throws Exception {
+    CartesianProductEdgeManagerConfig emConfig =
+      new CartesianProductEdgeManagerConfig(true, new String[]{"v0","v1"}, new int[]{3,4}, null, null);
+    when(mockContext.getDestinationVertexNumTasks()).thenReturn(12);
+    testTwoWayV0(emConfig);
+    testTwoWayV1(emConfig);
+  }
+
+  /** Verifies routing and task/IO counts when the edge's source vertex is v0. */
+  private void testTwoWayV0(CartesianProductEdgeManagerConfig config) throws Exception {
+    when(mockContext.getSourceVertexName()).thenReturn("v0");
+    when(mockContext.getSourceVertexNumTasks()).thenReturn(2);
+    edgeManager.initialize(config);
+
+    EventRouteMetadata compositeRoute =
+      edgeManager.routeCompositeDataMovementEventToDestination(1, 1);
+    assertNotNull(compositeRoute);
+    assertEquals(1, compositeRoute.getNumEvents());
+    assertArrayEquals(new int[]{0}, compositeRoute.getSourceIndices());
+    assertArrayEquals(new int[]{1}, compositeRoute.getTargetIndices());
+
+    EventRouteMetadata dmeRoute = edgeManager.routeDataMovementEventToDestination(1, 0, 1);
+    assertNotNull(dmeRoute);
+    assertEquals(1, dmeRoute.getNumEvents());
+    assertArrayEquals(new int[]{1}, dmeRoute.getTargetIndices());
+
+    // the same call with a different second argument must not be routed
+    assertNull(edgeManager.routeDataMovementEventToDestination(1, 1, 1));
+
+    EventRouteMetadata failedRoute = edgeManager.routeInputSourceTaskFailedEventToDestination(1, 1);
+    assertNotNull(failedRoute);
+    assertEquals(1, failedRoute.getNumEvents());
+    assertArrayEquals(new int[]{1}, failedRoute.getTargetIndices());
+
+    assertEquals(1, edgeManager.routeInputErrorEventToSource(1, 1));
+
+    // 12 = destination task count, 2 = v0 task count, 3 = v0 partition count
+    assertEquals(12, edgeManager.getNumDestinationConsumerTasks(1));
+    assertEquals(2, edgeManager.getNumDestinationTaskPhysicalInputs(10));
+    assertEquals(3, edgeManager.getNumSourceTaskPhysicalOutputs(2));
+  }
+
+  /** Verifies routing and task/IO counts when the edge's source vertex is v1. */
+  private void testTwoWayV1(CartesianProductEdgeManagerConfig config) throws Exception {
+    when(mockContext.getSourceVertexName()).thenReturn("v1");
+    when(mockContext.getSourceVertexNumTasks()).thenReturn(3);
+    edgeManager.initialize(config);
+
+    EventRouteMetadata compositeRoute =
+      edgeManager.routeCompositeDataMovementEventToDestination(1, 1);
+    assertNotNull(compositeRoute);
+    assertEquals(1, compositeRoute.getNumEvents());
+    assertArrayEquals(new int[]{1}, compositeRoute.getSourceIndices());
+    assertArrayEquals(new int[]{1}, compositeRoute.getTargetIndices());
+
+    EventRouteMetadata failedRoute = edgeManager.routeInputSourceTaskFailedEventToDestination(1, 1);
+    assertNotNull(failedRoute);
+    assertEquals(1, failedRoute.getNumEvents());
+    assertArrayEquals(new int[]{1}, failedRoute.getTargetIndices());
+
+    assertEquals(2, edgeManager.routeInputErrorEventToSource(1, 2));
+
+    // 12 = destination task count, 3 = v1 task count, 4 = v1 partition count
+    assertEquals(12, edgeManager.getNumDestinationConsumerTasks(1));
+    assertEquals(3, edgeManager.getNumDestinationTaskPhysicalInputs(10));
+    assertEquals(4, edgeManager.getNumSourceTaskPhysicalOutputs(2));
+  }
+
+  public static class TestFilter extends CartesianProductFilter {
+    char op;
+
+    public TestFilter(UserPayload payload) {
+      super(payload);
+      op = payload.getPayload().getChar();
+    }
+
+    @Override
+    public boolean isValidCombination(Map<String, Integer> vertexPartitionMap) {
+      switch (op) {
+        case '>':
+          return vertexPartitionMap.get("v0") > vertexPartitionMap.get("v1");
+        case '<':
+          return vertexPartitionMap.get("v0") < vertexPartitionMap.get("v1");
+        default:
+          return true;
+      }
+    }
+  }
+
+  /**
+   * Two-way product with the '>' filter (v0 partition must exceed v1 partition).
+   * v0 has 2 tasks producing 3 partitions and v1 has 3 tasks producing 4 partitions. Only the
+   * pairs (1,0), (2,0) and (2,1) pass the filter, so the destination vertex has 3 tasks.
+   */
+  @Test(timeout = 5000)
+  public void testTwoWayWithFilter() throws Exception {
+    ByteBuffer buffer = ByteBuffer.allocate(2);
+    buffer.putChar('>');
+    buffer.flip();
+    CartesianProductFilterDescriptor filterDescriptor =
+      new CartesianProductFilterDescriptor(TestFilter.class.getName())
+        .setUserPayload(UserPayload.create(buffer));
+    CartesianProductEdgeManagerConfig emConfig =
+      new CartesianProductEdgeManagerConfig(true, new String[]{"v0","v1"}, new int[]{3,4}, null,
+        filterDescriptor);
+    when(mockContext.getDestinationVertexNumTasks()).thenReturn(3);
+    testTwoWayV0WithFilter(emConfig);
+    testTwoWayV1WithFilter(emConfig);
+  }
+
+  /** Verifies routing and task/IO counts for source v0 when the '>' filter is active. */
+  private void testTwoWayV0WithFilter(CartesianProductEdgeManagerConfig config) throws Exception {
+    when(mockContext.getSourceVertexName()).thenReturn("v0");
+    when(mockContext.getSourceVertexNumTasks()).thenReturn(2);
+    edgeManager.initialize(config);
+
+    EventRouteMetadata compositeRoute =
+      edgeManager.routeCompositeDataMovementEventToDestination(1, 1);
+    assertNotNull(compositeRoute);
+    assertEquals(1, compositeRoute.getNumEvents());
+    assertArrayEquals(new int[]{2}, compositeRoute.getSourceIndices());
+    assertArrayEquals(new int[]{1}, compositeRoute.getTargetIndices());
+
+    EventRouteMetadata failedRoute = edgeManager.routeInputSourceTaskFailedEventToDestination(1, 1);
+    assertNotNull(failedRoute);
+    assertEquals(1, failedRoute.getNumEvents());
+    assertArrayEquals(new int[]{1}, failedRoute.getTargetIndices());
+
+    assertEquals(1, edgeManager.routeInputErrorEventToSource(1, 1));
+
+    // 3 = filtered destination task count, 2 = v0 task count, 3 = v0 partition count
+    assertEquals(3, edgeManager.getNumDestinationConsumerTasks(1));
+    assertEquals(2, edgeManager.getNumDestinationTaskPhysicalInputs(1));
+    assertEquals(3, edgeManager.getNumSourceTaskPhysicalOutputs(2));
+  }
+
+  private void testTwoWayV1WithFilter(CartesianProductEdgeManagerConfig config) throws Exception {
+    when(mockContext.getSourceVertexName()).thenReturn("v1");
+    when(mockContext.getSourceVertexNumTasks()).thenReturn(3);
+    edgeManager.initialize(config);
+    // dest task 1 reads v1 partition 0; source task 1 lands at input index 1
+    EventRouteMetadata routingData = edgeManager.routeCompositeDataMovementEventToDestination(1, 1);
+    assertNotNull(routingData);
+    assertEquals(1, routingData.getNumEvents());
+    assertArrayEquals(new int[]{0}, routingData.getSourceIndices());
+    assertArrayEquals(new int[]{1}, routingData.getTargetIndices());
+
+    routingData = edgeManager.routeInputSourceTaskFailedEventToDestination(1, 1);
+    assertNotNull(routingData);
+    assertEquals(1, routingData.getNumEvents());
+    assertArrayEquals(new int[]{1}, routingData.getTargetIndices());
+
+    assertEquals(2, edgeManager.routeInputErrorEventToSource(1, 2));
+
+    assertEquals(3, edgeManager.getNumDestinationConsumerTasks(1));
+    assertEquals(3, edgeManager.getNumDestinationTaskPhysicalInputs(1));
+    assertEquals(4, edgeManager.getNumSourceTaskPhysicalOutputs(2));
+  }
+
+  /**
+   * Three-way partitioned product. v0 has 2 tasks generating 4 partitions, v1 has 3 tasks
+   * generating 3 partitions and v2 has 4 tasks generating 2 partitions, so the destination
+   * vertex gets 4 * 3 * 2 = 24 tasks.
+   */
+  @Test(timeout = 5000)
+  public void testThreeWay() throws Exception {
+    when(mockContext.getDestinationVertexNumTasks()).thenReturn(24);
+    CartesianProductEdgeManagerConfig cfg = new CartesianProductEdgeManagerConfig(
+      true, new String[]{"v0", "v1", "v2"}, new int[]{4, 3, 2}, null, null);
+    testThreeWayV0(cfg);
+    testThreeWayV1(cfg);
+    testThreeWayV2(cfg);
+  }
+
+  private void testThreeWayV0(CartesianProductEdgeManagerConfig config) throws Exception {
+    when(mockContext.getSourceVertexName()).thenReturn("v0");
+    when(mockContext.getSourceVertexNumTasks()).thenReturn(2);
+    edgeManager.initialize(config);
+
+    // dest task 1 reads v0 partition 0; source task 1 lands at input index 1
+    EventRouteMetadata routingData = edgeManager.routeCompositeDataMovementEventToDestination(1, 1);
+    assertNotNull(routingData);
+    assertEquals(1, routingData.getNumEvents());
+    assertArrayEquals(new int[]{0}, routingData.getSourceIndices());
+    assertArrayEquals(new int[]{1}, routingData.getTargetIndices());
+
+    routingData = edgeManager.routeInputSourceTaskFailedEventToDestination(1, 1);
+    assertNotNull(routingData);
+    assertEquals(1, routingData.getNumEvents());
+    assertArrayEquals(new int[]{1}, routingData.getTargetIndices());
+
+    assertEquals(1, edgeManager.routeInputErrorEventToSource(1, 1));
+
+    assertEquals(24, edgeManager.getNumDestinationConsumerTasks(1));
+    assertEquals(2, edgeManager.getNumDestinationTaskPhysicalInputs(10));
+    assertEquals(4, edgeManager.getNumSourceTaskPhysicalOutputs(1));
+  }
+
+  private void testThreeWayV1(CartesianProductEdgeManagerConfig cfg) throws Exception {
+    when(mockContext.getSourceVertexNumTasks()).thenReturn(3);
+    when(mockContext.getSourceVertexName()).thenReturn("v1");
+    edgeManager.initialize(cfg);
+    // dest task 1 reads v1 partition 0; source task 1 lands at input index 1
+    EventRouteMetadata route = edgeManager.routeCompositeDataMovementEventToDestination(1, 1);
+    assertNotNull(route);
+    assertEquals(1, route.getNumEvents());
+    assertArrayEquals(new int[]{0}, route.getSourceIndices());
+    assertArrayEquals(new int[]{1}, route.getTargetIndices());
+
+    route = edgeManager.routeInputSourceTaskFailedEventToDestination(1, 1);
+    assertNotNull(route);
+    assertEquals(1, route.getNumEvents());
+    assertArrayEquals(new int[]{1}, route.getTargetIndices());
+    // input index 2 of dest task 1 maps back to source task 2
+    assertEquals(2, edgeManager.routeInputErrorEventToSource(1, 2));
+
+    assertEquals(3, edgeManager.getNumSourceTaskPhysicalOutputs(2));
+    assertEquals(3, edgeManager.getNumDestinationTaskPhysicalInputs(10));
+    assertEquals(24, edgeManager.getNumDestinationConsumerTasks(1));
+  }
+
+  private void testThreeWayV2(CartesianProductEdgeManagerConfig cfg) throws Exception {
+    when(mockContext.getSourceVertexNumTasks()).thenReturn(4);
+    when(mockContext.getSourceVertexName()).thenReturn("v2");
+    edgeManager.initialize(cfg);
+    // dest task 1 reads v2 partition 1; source task 1 lands at input index 1
+    EventRouteMetadata route = edgeManager.routeCompositeDataMovementEventToDestination(1, 1);
+    assertNotNull(route);
+    assertEquals(1, route.getNumEvents());
+    assertArrayEquals(new int[]{1}, route.getSourceIndices());
+    assertArrayEquals(new int[]{1}, route.getTargetIndices());
+
+    route = edgeManager.routeInputSourceTaskFailedEventToDestination(1, 1);
+    assertNotNull(route);
+    assertEquals(1, route.getNumEvents());
+    assertArrayEquals(new int[]{1}, route.getTargetIndices());
+    // input index 2 of dest task 1 maps back to source task 2
+    assertEquals(2, edgeManager.routeInputErrorEventToSource(1, 2));
+
+    assertEquals(2, edgeManager.getNumSourceTaskPhysicalOutputs(2));
+    assertEquals(4, edgeManager.getNumDestinationTaskPhysicalInputs(10));
+    assertEquals(24, edgeManager.getNumDestinationConsumerTasks(1));
+  }
+}

http://git-wip-us.apache.org/repos/asf/tez/blob/1a068b23/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductEdgeManagerUnpartitioned.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductEdgeManagerUnpartitioned.java b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductEdgeManagerUnpartitioned.java
new file mode 100644
index 0000000..4c69482
--- /dev/null
+++ b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductEdgeManagerUnpartitioned.java
@@ -0,0 +1,240 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.tez.runtime.library.cartesianproduct;
+
+import org.apache.tez.dag.api.EdgeManagerPluginContext;
+import org.apache.tez.dag.api.EdgeManagerPluginOnDemand.EventRouteMetadata;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class TestCartesianProductEdgeManagerUnpartitioned {
+  private EdgeManagerPluginContext mockContext;
+  private CartesianProductEdgeManagerUnpartitioned edgeManager;
+
+  @Before
+  public void setup() {
+    mockContext = mock(EdgeManagerPluginContext.class);
+    edgeManager = new CartesianProductEdgeManagerUnpartitioned(mockContext);
+  }
+
+  /**
+   * Vertex v0 has 2 tasks
+   * Vertex v1 has 3 tasks
+   */
+  @Test(timeout = 5000)
+  public void testTwoWay() throws Exception {
+    CartesianProductEdgeManagerConfig emConfig =
+      new CartesianProductEdgeManagerConfig(false, new String[]{"v0","v1"}, null, new int[]{2,3}, null);
+    testTwoWayV0(emConfig);
+    testTwoWayV1(emConfig);
+  }
+
+  private void testTwoWayV0(CartesianProductEdgeManagerConfig config) throws Exception {
+    when(mockContext.getSourceVertexName()).thenReturn("v0");
+    when(mockContext.getSourceVertexNumTasks()).thenReturn(2);
+    edgeManager.initialize(config);
+
+    EventRouteMetadata routingData = edgeManager.routeCompositeDataMovementEventToDestination(1, 1);
+    assertNull(routingData);
+
+    routingData = edgeManager.routeCompositeDataMovementEventToDestination(1, 3);
+    assertNotNull(routingData);
+    assertEquals(1, routingData.getNumEvents());
+    assertArrayEquals(new int[]{0}, routingData.getTargetIndices());
+    assertArrayEquals(new int[]{0}, routingData.getSourceIndices());
+
+    routingData = edgeManager.routeInputSourceTaskFailedEventToDestination(1, 1);
+    assertNull(routingData);
+
+    routingData = edgeManager.routeCompositeDataMovementEventToDestination(1, 3);
+    assertNotNull(routingData);
+    assertEquals(1, routingData.getNumEvents());
+    assertArrayEquals(new int[]{0}, routingData.getTargetIndices());
+    assertArrayEquals(new int[]{0}, routingData.getSourceIndices());
+
+    assertEquals(0, edgeManager.routeInputErrorEventToSource(1, 0));
+
+    assertEquals(1, edgeManager.getNumDestinationTaskPhysicalInputs(1));
+    assertEquals(1, edgeManager.getNumSourceTaskPhysicalOutputs(1));
+    assertEquals(3, edgeManager.getNumDestinationConsumerTasks(1));
+  }
+
+  private void testTwoWayV1(CartesianProductEdgeManagerConfig config) throws Exception {
+    when(mockContext.getSourceVertexName()).thenReturn("v1");
+    when(mockContext.getSourceVertexNumTasks()).thenReturn(3);
+    edgeManager.initialize(config);
+
+    EventRouteMetadata routingData = edgeManager.routeCompositeDataMovementEventToDestination(1, 2);
+    assertNull(routingData);
+
+    routingData = edgeManager.routeCompositeDataMovementEventToDestination(1, 1);
+    assertNotNull(routingData);
+    assertEquals(1, routingData.getNumEvents());
+    assertArrayEquals(new int[]{0}, routingData.getTargetIndices());
+    assertArrayEquals(new int[]{0}, routingData.getSourceIndices());
+
+    routingData = edgeManager.routeInputSourceTaskFailedEventToDestination(1, 2);
+    assertNull(routingData);
+
+    routingData = edgeManager.routeCompositeDataMovementEventToDestination(1, 1);
+    assertNotNull(routingData);
+    assertEquals(1, routingData.getNumEvents());
+    assertArrayEquals(new int[]{0}, routingData.getTargetIndices());
+    assertArrayEquals(new int[]{0}, routingData.getSourceIndices());
+
+    assertEquals(1, edgeManager.routeInputErrorEventToSource(1, 0));
+
+    assertEquals(1, edgeManager.getNumDestinationTaskPhysicalInputs(1));
+    assertEquals(1, edgeManager.getNumSourceTaskPhysicalOutputs(1));
+    assertEquals(2, edgeManager.getNumDestinationConsumerTasks(1));
+  }
+
+  /**
+   * Vertex v0 has 2 tasks
+   * Vertex v1 has 3 tasks
+   * Vertex v2 has 4 tasks
+   */
+  @Test(timeout = 5000)
+  public void testThreeWay() throws Exception {
+    CartesianProductEdgeManagerConfig emConfig =
+      new CartesianProductEdgeManagerConfig(false, new String[]{"v0","v1","v2"}, null, new int[]{2,3,4}, null);
+    testThreeWayV0(emConfig);
+    testThreeWayV1(emConfig);
+    testThreeWayV2(emConfig);
+  }
+
+  private void testThreeWayV0(CartesianProductEdgeManagerConfig config) throws Exception {
+    when(mockContext.getSourceVertexName()).thenReturn("v0");
+    when(mockContext.getSourceVertexNumTasks()).thenReturn(2);
+    edgeManager.initialize(config);
+
+    EventRouteMetadata routingData = edgeManager.routeCompositeDataMovementEventToDestination(1, 1);
+    assertNull(routingData);
+
+    routingData = edgeManager.routeCompositeDataMovementEventToDestination(1, 12);
+    assertNotNull(routingData);
+    assertEquals(1, routingData.getNumEvents());
+    assertArrayEquals(new int[]{0}, routingData.getTargetIndices());
+    assertArrayEquals(new int[]{0}, routingData.getSourceIndices());
+
+    routingData = edgeManager.routeInputSourceTaskFailedEventToDestination(1, 1);
+    assertNull(routingData);
+
+    routingData = edgeManager.routeCompositeDataMovementEventToDestination(1, 12);
+    assertNotNull(routingData);
+    assertEquals(1, routingData.getNumEvents());
+    assertArrayEquals(new int[]{0}, routingData.getTargetIndices());
+    assertArrayEquals(new int[]{0}, routingData.getSourceIndices());
+
+    assertEquals(0, edgeManager.routeInputErrorEventToSource(1, 0));
+
+    assertEquals(1, edgeManager.getNumDestinationTaskPhysicalInputs(1));
+    assertEquals(1, edgeManager.getNumSourceTaskPhysicalOutputs(1));
+    assertEquals(12, edgeManager.getNumDestinationConsumerTasks(1));
+  }
+
+  private void testThreeWayV1(CartesianProductEdgeManagerConfig config) throws Exception {
+    when(mockContext.getSourceVertexName()).thenReturn("v1");
+    when(mockContext.getSourceVertexNumTasks()).thenReturn(3);
+    edgeManager.initialize(config);
+
+    EventRouteMetadata routingData = edgeManager.routeCompositeDataMovementEventToDestination(1, 1);
+    assertNull(routingData);
+
+    routingData = edgeManager.routeCompositeDataMovementEventToDestination(1, 16);
+    assertNotNull(routingData);
+    assertEquals(1, routingData.getNumEvents());
+    assertArrayEquals(new int[]{0}, routingData.getTargetIndices());
+    assertArrayEquals(new int[]{0}, routingData.getSourceIndices());
+
+    routingData = edgeManager.routeInputSourceTaskFailedEventToDestination(1, 1);
+    assertNull(routingData);
+
+    routingData = edgeManager.routeCompositeDataMovementEventToDestination(1, 16);
+    assertNotNull(routingData);
+    assertEquals(1, routingData.getNumEvents());
+    assertArrayEquals(new int[]{0}, routingData.getTargetIndices());
+    assertArrayEquals(new int[]{0}, routingData.getSourceIndices());
+
+    assertEquals(0, edgeManager.routeInputErrorEventToSource(1, 0));
+
+    assertEquals(1, edgeManager.getNumDestinationTaskPhysicalInputs(1));
+    assertEquals(1, edgeManager.getNumSourceTaskPhysicalOutputs(1));
+    assertEquals(8, edgeManager.getNumDestinationConsumerTasks(1));
+  }
+
+  private void testThreeWayV2(CartesianProductEdgeManagerConfig config) throws Exception {
+    when(mockContext.getSourceVertexName()).thenReturn("v2");
+    when(mockContext.getSourceVertexNumTasks()).thenReturn(4);
+    edgeManager.initialize(config);
+
+    EventRouteMetadata routingData = edgeManager.routeCompositeDataMovementEventToDestination(1, 0);
+    assertNull(routingData);
+
+    routingData = edgeManager.routeCompositeDataMovementEventToDestination(1, 13);
+    assertNotNull(routingData);
+    assertEquals(1, routingData.getNumEvents());
+    assertArrayEquals(new int[]{0}, routingData.getTargetIndices());
+    assertArrayEquals(new int[]{0}, routingData.getSourceIndices());
+
+    routingData = edgeManager.routeInputSourceTaskFailedEventToDestination(1, 0);
+    assertNull(routingData);
+
+    routingData = edgeManager.routeCompositeDataMovementEventToDestination(1, 13);
+    assertNotNull(routingData);
+    assertEquals(1, routingData.getNumEvents());
+    assertArrayEquals(new int[]{0}, routingData.getTargetIndices());
+    assertArrayEquals(new int[]{0}, routingData.getSourceIndices());
+
+    assertEquals(1, edgeManager.routeInputErrorEventToSource(1, 0));
+
+    assertEquals(1, edgeManager.getNumDestinationTaskPhysicalInputs(1));
+    assertEquals(1, edgeManager.getNumSourceTaskPhysicalOutputs(1));
+    assertEquals(6, edgeManager.getNumDestinationConsumerTasks(1));
+  }
+
+  @Test(timeout = 5000)
+  public void testZeroSrcTask() {
+    CartesianProductEdgeManagerConfig emConfig =
+      new CartesianProductEdgeManagerConfig(false, new String[]{"v0", "v1"}, null, new int[]{2, 0}, null);
+    testZeroSrcTaskV0(emConfig);
+    testZeroSrcTaskV1(emConfig);
+  }
+
+  private void testZeroSrcTaskV0(CartesianProductEdgeManagerConfig config) {
+    when(mockContext.getSourceVertexName()).thenReturn("v0");
+    when(mockContext.getSourceVertexNumTasks()).thenReturn(2);
+    edgeManager.initialize(config);
+
+    assertEquals(0, edgeManager.getNumDestinationConsumerTasks(0));
+    assertEquals(0, edgeManager.getNumDestinationConsumerTasks(1));
+  }
+
+  private void testZeroSrcTaskV1(CartesianProductEdgeManagerConfig config) {
+    when(mockContext.getSourceVertexName()).thenReturn("v1");
+    when(mockContext.getSourceVertexNumTasks()).thenReturn(0);
+    edgeManager.initialize(config);
+  }
+}

http://git-wip-us.apache.org/repos/asf/tez/blob/1a068b23/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManager.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManager.java b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManager.java
new file mode 100644
index 0000000..755c578
--- /dev/null
+++ b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManager.java
@@ -0,0 +1,67 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.tez.runtime.library.cartesianproduct;
+
+import org.apache.tez.dag.api.EdgeManagerPluginDescriptor;
+import org.apache.tez.dag.api.EdgeProperty;
+import org.apache.tez.dag.api.TezConfiguration;
+import org.apache.tez.dag.api.VertexManagerPluginContext;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class TestCartesianProductVertexManager {
+  @Test(timeout = 5000)
+  public void testInitialize() throws Exception {
+    VertexManagerPluginContext context = mock(VertexManagerPluginContext.class);
+    CartesianProductVertexManager vertexManager = new CartesianProductVertexManager(context);
+    TezConfiguration conf = new TezConfiguration();
+
+    // partitioned case
+    CartesianProductConfig config =
+      new CartesianProductConfig(new int[]{2,3}, new String[]{"v0", "v1"}, null);
+    when(context.getUserPayload()).thenReturn(config.toUserPayload(conf));
+    EdgeProperty edgeProperty =
+      EdgeProperty.create(EdgeManagerPluginDescriptor.create(
+        CartesianProductEdgeManager.class.getName()), null, null, null, null);
+    Map<String, EdgeProperty> edgePropertyMap = new HashMap<>();
+    edgePropertyMap.put("v0", edgeProperty);
+    edgePropertyMap.put("v1", edgeProperty);
+    when(context.getInputVertexEdgeProperties()).thenReturn(edgePropertyMap);
+    vertexManager.initialize();
+    assertTrue(vertexManager.getVertexManagerReal()
+      instanceof CartesianProductVertexManagerPartitioned);
+
+    // unpartitioned case
+    List<String> sourceVertices = new ArrayList<>();
+    sourceVertices.add("v0");
+    sourceVertices.add("v1");
+    config = new CartesianProductConfig(sourceVertices);
+    when(context.getUserPayload()).thenReturn(config.toUserPayload(conf));
+    vertexManager.initialize();
+    assertTrue(vertexManager.getVertexManagerReal()
+      instanceof CartesianProductVertexManagerUnpartitioned);
+  }
+}

http://git-wip-us.apache.org/repos/asf/tez/blob/1a068b23/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManagerPartitioned.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManagerPartitioned.java b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManagerPartitioned.java
new file mode 100644
index 0000000..9aca647
--- /dev/null
+++ b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManagerPartitioned.java
@@ -0,0 +1,230 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.tez.runtime.library.cartesianproduct;
+
+import org.apache.hadoop.yarn.api.records.ApplicationId;
+import org.apache.tez.dag.api.EdgeManagerPluginDescriptor;
+import org.apache.tez.dag.api.EdgeProperty;
+import org.apache.tez.dag.api.TezConfiguration;
+import org.apache.tez.dag.api.UserPayload;
+import org.apache.tez.dag.api.VertexLocationHint;
+import org.apache.tez.dag.api.VertexManagerPluginContext;
+import org.apache.tez.dag.api.VertexManagerPluginContext.ScheduleTaskRequest;
+import org.apache.tez.dag.api.event.VertexState;
+import org.apache.tez.dag.api.event.VertexStateUpdate;
+import org.apache.tez.dag.records.TaskAttemptIdentifierImpl;
+import org.apache.tez.dag.records.TezDAGID;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.dag.records.TezTaskID;
+import org.apache.tez.dag.records.TezVertexID;
+import org.apache.tez.runtime.api.TaskAttemptIdentifier;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Captor;
+import org.mockito.Matchers;
+import org.mockito.Mockito;
+import org.mockito.MockitoAnnotations;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.mockito.Matchers.isNull;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+public class TestCartesianProductVertexManagerPartitioned {
+  @Captor
+  private ArgumentCaptor<Map<String, EdgeProperty>> edgePropertiesCaptor;
+  @Captor
+  private ArgumentCaptor<List<ScheduleTaskRequest>> scheduleTaskRequestCaptor;
+  private TezConfiguration conf = new TezConfiguration();
+
+  @Before
+  public void init() {
+    MockitoAnnotations.initMocks(this);
+  }
+
+  public static class TestFilter extends CartesianProductFilter {
+    public TestFilter(UserPayload payload) {
+      super(payload);
+    }
+
+    @Override
+    public boolean isValidCombination(Map<String, Integer> vertexPartitionMap) {
+      return vertexPartitionMap.get("v0") > vertexPartitionMap.get("v1");
+    }
+  }
+
+  private void testReconfigureVertexHelper(CartesianProductConfig config, int parallelism)
+    throws Exception {
+    VertexManagerPluginContext mockContext = mock(VertexManagerPluginContext.class);
+    when(mockContext.getUserPayload()).thenReturn(config.toUserPayload(conf));
+
+    EdgeProperty edgeProperty =
+      EdgeProperty.create(EdgeManagerPluginDescriptor.create(
+        CartesianProductEdgeManager.class.getName()), null, null, null, null);
+    Map<String, EdgeProperty> inputEdgeProperties = new HashMap<>();
+    for (String vertex : config.getSourceVertices()) {
+      inputEdgeProperties.put(vertex, edgeProperty);
+    }
+    when(mockContext.getInputVertexEdgeProperties()).thenReturn(inputEdgeProperties);
+    CartesianProductVertexManager vertexManager = new CartesianProductVertexManager(mockContext);
+    vertexManager.initialize();
+    ArgumentCaptor<Integer> parallelismCaptor = ArgumentCaptor.forClass(Integer.class);
+    // a single CONFIGURED update from v0 must trigger exactly one reconfigureVertex call
+    vertexManager.onVertexStateUpdated(new VertexStateUpdate("v0", VertexState.CONFIGURED));
+    verify(mockContext, times(1)).reconfigureVertex(parallelismCaptor.capture(),
+      isNull(VertexLocationHint.class), edgePropertiesCaptor.capture());
+    assertEquals(parallelism, (int) parallelismCaptor.getValue());
+    assertNull(edgePropertiesCaptor.getValue());
+  }
+
+  @Test(timeout = 5000)
+  public void testReconfigureVertex() throws Exception {
+    testReconfigureVertexHelper(
+      new CartesianProductConfig(new int[]{5,5}, new String[]{"v0", "v1"},
+        new CartesianProductFilterDescriptor(TestFilter.class.getName())), 10);
+    testReconfigureVertexHelper(
+      new CartesianProductConfig(new int[]{5,5}, new String[]{"v0", "v1"}, null), 25);
+  }
+
+  @Test(timeout = 5000)
+  public void testScheduling() throws Exception {
+    CartesianProductConfig config = new CartesianProductConfig(new int[]{2,2},
+      new String[]{"v0", "v1"}, null);
+    VertexManagerPluginContext mockContext = mock(VertexManagerPluginContext.class);
+    when(mockContext.getUserPayload()).thenReturn(config.toUserPayload(conf));
+    Set<String> inputVertices = new HashSet<>();
+    inputVertices.add("v0");
+    inputVertices.add("v1");
+    when(mockContext.getVertexInputNames()).thenReturn(inputVertices);
+    when(mockContext.getVertexNumTasks("v0")).thenReturn(4);
+    when(mockContext.getVertexNumTasks("v1")).thenReturn(4);
+    EdgeProperty edgeProperty =
+      EdgeProperty.create(EdgeManagerPluginDescriptor.create(
+        CartesianProductEdgeManager.class.getName()), null, null, null, null);
+    Map<String, EdgeProperty> inputEdgeProperties = new HashMap<>();
+    inputEdgeProperties.put("v0", edgeProperty);
+    inputEdgeProperties.put("v1", edgeProperty);
+    when(mockContext.getInputVertexEdgeProperties()).thenReturn(inputEdgeProperties);
+    CartesianProductVertexManager vertexManager = new CartesianProductVertexManager(mockContext);
+    vertexManager.initialize();
+
+    vertexManager.onVertexStarted(new ArrayList<TaskAttemptIdentifier>());
+    vertexManager.onVertexStateUpdated(new VertexStateUpdate("v0", VertexState.CONFIGURED));
+    vertexManager.onVertexStateUpdated(new VertexStateUpdate("v1", VertexState.CONFIGURED));
+
+    // deep stubs: successive calls report v0#0, v0#1, v1#0, v1#1, v0#2, v0#3, v1#2, v1#3
+    TaskAttemptIdentifier taId = mock(TaskAttemptIdentifier.class, Mockito.RETURNS_DEEP_STUBS);
+    when(taId.getTaskIdentifier().getVertexIdentifier().getName()).thenReturn("v0", "v0", "v1",
+      "v1", "v0", "v0", "v1", "v1");
+    when(taId.getTaskIdentifier().getIdentifier()).thenReturn(0, 1, 0, 1, 2, 3, 2, 3);
+
+    for (int i = 0; i < 2; i++) {
+      vertexManager.onSourceTaskCompleted(taId);
+      verify(mockContext, never()).scheduleTasks(Matchers.<List<ScheduleTaskRequest>>any());
+    }
+
+    List<ScheduleTaskRequest> scheduleTaskRequests;
+
+    vertexManager.onSourceTaskCompleted(taId);
+    verify(mockContext, times(1)).scheduleTasks(scheduleTaskRequestCaptor.capture());
+    scheduleTaskRequests = scheduleTaskRequestCaptor.getValue();
+    assertEquals(1, scheduleTaskRequests.size());
+    assertEquals(0, scheduleTaskRequests.get(0).getTaskIndex());
+
+    vertexManager.onSourceTaskCompleted(taId);
+    verify(mockContext, times(2)).scheduleTasks(scheduleTaskRequestCaptor.capture());
+    scheduleTaskRequests = scheduleTaskRequestCaptor.getValue();
+    assertEquals(1, scheduleTaskRequests.size());
+    assertEquals(1, scheduleTaskRequests.get(0).getTaskIndex());
+
+    vertexManager.onSourceTaskCompleted(taId);
+    verify(mockContext, times(3)).scheduleTasks(scheduleTaskRequestCaptor.capture());
+    scheduleTaskRequests = scheduleTaskRequestCaptor.getValue();
+    assertEquals(1, scheduleTaskRequests.size());
+    assertEquals(2, scheduleTaskRequests.get(0).getTaskIndex());
+
+    vertexManager.onSourceTaskCompleted(taId);
+    verify(mockContext, times(4)).scheduleTasks(scheduleTaskRequestCaptor.capture());
+    scheduleTaskRequests = scheduleTaskRequestCaptor.getValue();
+    assertEquals(1, scheduleTaskRequests.size());
+    assertEquals(3, scheduleTaskRequests.get(0).getTaskIndex());
+
+    for (int i = 0; i < 2; i++) {
+      vertexManager.onSourceTaskCompleted(taId);
+      verify(mockContext, times(4)).scheduleTasks(Matchers.<List<ScheduleTaskRequest>>any());
+    }
+  }
+
+  @Test(timeout = 5000)
+  public void testVertexStartWithCompletion() throws Exception {
+    CartesianProductConfig config = new CartesianProductConfig(new int[]{2,2},
+      new String[]{"v0", "v1"}, null);
+    VertexManagerPluginContext mockContext = mock(VertexManagerPluginContext.class);
+    when(mockContext.getUserPayload()).thenReturn(config.toUserPayload(conf));
+    Set<String> inputVertices = new HashSet<>();
+    inputVertices.add("v0");
+    inputVertices.add("v1");
+    when(mockContext.getVertexInputNames()).thenReturn(inputVertices);
+    when(mockContext.getVertexNumTasks("v0")).thenReturn(4);
+    when(mockContext.getVertexNumTasks("v1")).thenReturn(4);
+    EdgeProperty edgeProperty =
+      EdgeProperty.create(EdgeManagerPluginDescriptor.create(
+        CartesianProductEdgeManager.class.getName()), null, null, null, null);
+    Map<String, EdgeProperty> inputEdgeProperties = new HashMap<>();
+    inputEdgeProperties.put("v0", edgeProperty);
+    inputEdgeProperties.put("v1", edgeProperty);
+    when(mockContext.getInputVertexEdgeProperties()).thenReturn(inputEdgeProperties);
+    CartesianProductVertexManager vertexManager = new CartesianProductVertexManager(mockContext);
+    vertexManager.initialize();
+
+    vertexManager.onVertexStateUpdated(new VertexStateUpdate("v0", VertexState.CONFIGURED));
+    vertexManager.onVertexStateUpdated(new VertexStateUpdate("v1", VertexState.CONFIGURED));
+
+    List<TaskAttemptIdentifier> completions = new ArrayList<>();
+    TezDAGID dagId = TezDAGID.getInstance(ApplicationId.newInstance(0, 0), 0);
+    TezVertexID v0Id = TezVertexID.getInstance(dagId, 0);
+    TezVertexID v1Id = TezVertexID.getInstance(dagId, 1);
+
+    completions.add(new TaskAttemptIdentifierImpl("dag", "v0",
+      TezTaskAttemptID.getInstance(TezTaskID.getInstance(v0Id, 0), 0)));
+    completions.add(new TaskAttemptIdentifierImpl("dag", "v0",
+      TezTaskAttemptID.getInstance(TezTaskID.getInstance(v0Id, 1), 0)));
+    completions.add(new TaskAttemptIdentifierImpl("dag", "v1",
+      TezTaskAttemptID.getInstance(TezTaskID.getInstance(v1Id, 0), 0)));
+
+    vertexManager.onVertexStarted(completions);
+
+    List<ScheduleTaskRequest> scheduleTaskRequests;
+    verify(mockContext, times(1)).scheduleTasks(scheduleTaskRequestCaptor.capture());
+    scheduleTaskRequests = scheduleTaskRequestCaptor.getValue();
+    assertEquals(1, scheduleTaskRequests.size());
+    assertEquals(0, scheduleTaskRequests.get(0).getTaskIndex());
+  }
+}

http://git-wip-us.apache.org/repos/asf/tez/blob/1a068b23/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManagerUnpartitioned.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManagerUnpartitioned.java b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManagerUnpartitioned.java
new file mode 100644
index 0000000..f76de96
--- /dev/null
+++ b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManagerUnpartitioned.java
@@ -0,0 +1,194 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.tez.runtime.library.cartesianproduct;
+
+import org.apache.tez.dag.api.EdgeProperty;
+import org.apache.tez.dag.api.UserPayload;
+import org.apache.tez.dag.api.VertexLocationHint;
+import org.apache.tez.dag.api.VertexManagerPluginContext;
+import org.apache.tez.dag.api.VertexManagerPluginContext.ScheduleTaskRequest;
+import org.apache.tez.dag.api.event.VertexState;
+import org.apache.tez.dag.api.event.VertexStateUpdate;
+import org.apache.tez.dag.records.TaskAttemptIdentifierImpl;
+import org.apache.tez.dag.records.TezDAGID;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.dag.records.TezTaskID;
+import org.apache.tez.dag.records.TezVertexID;
+import org.apache.tez.runtime.api.TaskAttemptIdentifier;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Captor;
+import org.mockito.Matchers;
+import org.mockito.MockitoAnnotations;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyInt;
+import static org.mockito.Matchers.anyMapOf;
+import static org.mockito.Matchers.eq;
+import static org.mockito.Matchers.isNull;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+/**
+ * Unit tests for {@link CartesianProductVertexManagerUnpartitioned}. The plugin context is a
+ * Mockito mock in which source vertex v0 has 2 tasks and source vertex v1 has 3 tasks.
+ */
+public class TestCartesianProductVertexManagerUnpartitioned {
+  @Captor
+  private ArgumentCaptor<Map<String, EdgeProperty>> edgePropertiesCaptor;
+  @Captor
+  private ArgumentCaptor<List<ScheduleTaskRequest>> scheduleTaskRequestCaptor;
+  private CartesianProductVertexManagerUnpartitioned vertexManager;
+  private VertexManagerPluginContext context;
+  /**
+   * One completed attempt per source task: entries 0-1 are v0 tasks 0-1 and entries 2-4 are
+   * v1 tasks 0-2. Tests index into this list, so keep the order stable.
+   */
+  private List<TaskAttemptIdentifier> allCompletions;
+
+  @Before
+  public void setup() throws Exception {
+    MockitoAnnotations.initMocks(this);
+    context = mock(VertexManagerPluginContext.class);
+    vertexManager = new CartesianProductVertexManagerUnpartitioned(context);
+    when(context.getVertexNumTasks(eq("v0"))).thenReturn(2);
+    when(context.getVertexNumTasks(eq("v1"))).thenReturn(3);
+
+    CartesianProductVertexManagerConfig config =
+      new CartesianProductVertexManagerConfig(false, new String[]{"v0","v1"}, null, 0, 0, null);
+    vertexManager.initialize(config);
+    // each entry must be a distinct task (attempt 0), not another attempt of task 0
+    allCompletions = new ArrayList<>();
+    allCompletions.add(taskAttempt("v0", 0, 0));
+    allCompletions.add(taskAttempt("v0", 0, 1));
+    allCompletions.add(taskAttempt("v1", 1, 0));
+    allCompletions.add(taskAttempt("v1", 1, 1));
+    allCompletions.add(taskAttempt("v1", 1, 2));
+  }
+
+  /**
+   * Builds an identifier for attempt 0 of the given task of a source vertex.
+   *
+   * @param vertexName  source vertex name reported to the vertex manager
+   * @param vertexIndex index used to build the vertex ID inside DAG "0"
+   * @param taskIndex   task index within that vertex
+   * @return the task attempt identifier
+   */
+  private static TaskAttemptIdentifier taskAttempt(String vertexName, int vertexIndex,
+                                                   int taskIndex) {
+    TezVertexID vertexId = TezVertexID.getInstance(TezDAGID.getInstance("0", 0, 0), vertexIndex);
+    return new TaskAttemptIdentifierImpl("dag", vertexName,
+      TezTaskAttemptID.getInstance(TezTaskID.getInstance(vertexId, taskIndex), 0));
+  }
+
+  /**
+   * Reconfiguration happens only after both sources are CONFIGURED; the destination then gets
+   * parallelism 2 * 3 = 6 and each edge config records the source task counts {2, 3}.
+   */
+  @Test(timeout = 5000)
+  public void testReconfigureVertex() throws Exception {
+    ArgumentCaptor<Integer> parallelismCaptor = ArgumentCaptor.forClass(Integer.class);
+    vertexManager.onVertexStateUpdated(new VertexStateUpdate("v0", VertexState.CONFIGURED));
+    verify(context, never()).reconfigureVertex(
+      anyInt(), any(VertexLocationHint.class), anyMapOf(String.class, EdgeProperty.class));
+    vertexManager.onVertexStateUpdated(new VertexStateUpdate("v1", VertexState.CONFIGURED));
+    verify(context, times(1)).reconfigureVertex(parallelismCaptor.capture(),
+      isNull(VertexLocationHint.class), edgePropertiesCaptor.capture());
+    assertEquals(6, (int)parallelismCaptor.getValue());
+    Map<String, EdgeProperty> edgeProperties = edgePropertiesCaptor.getValue();
+    for (EdgeProperty edgeProperty : edgeProperties.values()) {
+      UserPayload payload = edgeProperty.getEdgeManagerDescriptor().getUserPayload();
+      CartesianProductEdgeManagerConfig newConfig =
+        CartesianProductEdgeManagerConfig.fromUserPayload(payload);
+      assertArrayEquals(new int[]{2,3}, newConfig.getNumTasks());
+    }
+  }
+
+  /**
+   * After reconfiguration, destination task 0 is scheduled exactly when v0 task 0 and
+   * v1 task 0 have both completed.
+   */
+  @Test(timeout = 5000)
+  public void testCompletionAfterReconfigured() throws Exception {
+    vertexManager.onVertexStarted(new ArrayList<TaskAttemptIdentifier>());
+    vertexManager.onVertexStateUpdated(new VertexStateUpdate("v0", VertexState.CONFIGURED));
+    vertexManager.onVertexStateUpdated(new VertexStateUpdate("v1", VertexState.CONFIGURED));
+    verify(context, never()).scheduleTasks(Matchers.<List<ScheduleTaskRequest>>any());
+    vertexManager.onSourceTaskCompleted(allCompletions.get(0));
+    verify(context, never()).scheduleTasks(Matchers.<List<ScheduleTaskRequest>>any());
+    vertexManager.onSourceTaskCompleted(allCompletions.get(2));
+    verify(context, times(1)).scheduleTasks(scheduleTaskRequestCaptor.capture());
+    List<ScheduleTaskRequest> requests = scheduleTaskRequestCaptor.getValue();
+    assertNotNull(requests);
+    assertEquals(1, requests.size());
+    assertEquals(0, requests.get(0).getTaskIndex());
+  }
+
+  /**
+   * Completions that arrive before reconfiguration are held; destination task 0 is scheduled
+   * only once the last source (v1) reports CONFIGURED.
+   */
+  @Test(timeout = 5000)
+  public void testCompletionBeforeReconfigured() throws Exception {
+    vertexManager.onVertexStarted(new ArrayList<TaskAttemptIdentifier>());
+    vertexManager.onSourceTaskCompleted(allCompletions.get(0));
+    verify(context, never()).scheduleTasks(Matchers.<List<ScheduleTaskRequest>>any());
+    vertexManager.onSourceTaskCompleted(allCompletions.get(2));
+    verify(context, never()).scheduleTasks(Matchers.<List<ScheduleTaskRequest>>any());
+    vertexManager.onVertexStateUpdated(new VertexStateUpdate("v0", VertexState.CONFIGURED));
+    verify(context, never()).scheduleTasks(Matchers.<List<ScheduleTaskRequest>>any());
+    vertexManager.onVertexStateUpdated(new VertexStateUpdate("v1", VertexState.CONFIGURED));
+    verify(context, times(1)).scheduleTasks(scheduleTaskRequestCaptor.capture());
+    List<ScheduleTaskRequest> requests = scheduleTaskRequestCaptor.getValue();
+    assertNotNull(requests);
+    assertEquals(1, requests.size());
+    assertEquals(0, requests.get(0).getTaskIndex());
+  }
+
+  /**
+   * Completions handed to onVertexStarted after reconfiguration immediately schedule
+   * destination task 0.
+   */
+  @Test(timeout = 5000)
+  public void testStartAfterReconfigured() throws Exception {
+    vertexManager.onVertexStateUpdated(new VertexStateUpdate("v0", VertexState.CONFIGURED));
+    vertexManager.onVertexStateUpdated(new VertexStateUpdate("v1", VertexState.CONFIGURED));
+
+    List<TaskAttemptIdentifier> completion = new ArrayList<>();
+    completion.add(allCompletions.get(0));
+    completion.add(allCompletions.get(2));
+    vertexManager.onVertexStarted(completion);
+    verify(context, times(1)).scheduleTasks(scheduleTaskRequestCaptor.capture());
+    List<ScheduleTaskRequest> requests = scheduleTaskRequestCaptor.getValue();
+    assertNotNull(requests);
+    assertEquals(1, requests.size());
+    assertEquals(0, requests.get(0).getTaskIndex());
+  }
+
+  /** Nothing is scheduled before reconfiguration, even if every source task has completed. */
+  @Test(timeout = 5000)
+  public void testStartBeforeReconfigured() throws Exception {
+    vertexManager.onVertexStarted(allCompletions);
+    verify(context, never()).scheduleTasks(Matchers.<List<ScheduleTaskRequest>>any());
+  }
+
+  /**
+   * Source vertex v1 has zero tasks; completing every v0 task must not throw or hang.
+   * NOTE(review): this test makes no assertions about scheduling.
+   */
+  @Test(timeout = 5000)
+  public void testZeroSrcTask() throws Exception {
+    context = mock(VertexManagerPluginContext.class);
+    vertexManager = new CartesianProductVertexManagerUnpartitioned(context);
+    when(context.getVertexNumTasks(eq("v0"))).thenReturn(2);
+    when(context.getVertexNumTasks(eq("v1"))).thenReturn(0);
+
+    CartesianProductVertexManagerConfig config =
+      new CartesianProductVertexManagerConfig(false, new String[]{"v0","v1"}, null, 0, 0, null);
+    vertexManager.initialize(config);
+    // v0 tasks 0 and 1 (both of v0's tasks), attempt 0 each
+    allCompletions = new ArrayList<>();
+    allCompletions.add(taskAttempt("v0", 0, 0));
+    allCompletions.add(taskAttempt("v0", 0, 1));
+
+    vertexManager.onVertexStarted(new ArrayList<TaskAttemptIdentifier>());
+    vertexManager.onVertexStateUpdated(new VertexStateUpdate("v0", VertexState.CONFIGURED));
+    vertexManager.onVertexStateUpdated(new VertexStateUpdate("v1", VertexState.CONFIGURED));
+    vertexManager.onSourceTaskCompleted(allCompletions.get(0));
+    vertexManager.onSourceTaskCompleted(allCompletions.get(1));
+  }
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/tez/blob/1a068b23/tez-tests/src/test/java/org/apache/tez/test/TestFaultTolerance.java
----------------------------------------------------------------------
diff --git a/tez-tests/src/test/java/org/apache/tez/test/TestFaultTolerance.java b/tez-tests/src/test/java/org/apache/tez/test/TestFaultTolerance.java
index 2d10f94..764ef0f 100644
--- a/tez-tests/src/test/java/org/apache/tez/test/TestFaultTolerance.java
+++ b/tez-tests/src/test/java/org/apache/tez/test/TestFaultTolerance.java
@@ -19,8 +19,17 @@
 package org.apache.tez.test;
 
 import java.io.IOException;
+import java.util.Arrays;
 import java.util.Random;
 
+import org.apache.tez.common.TezUtils;
+import org.apache.tez.dag.api.EdgeManagerPluginDescriptor;
+import org.apache.tez.dag.api.ProcessorDescriptor;
+import org.apache.tez.dag.api.UserPayload;
+import org.apache.tez.dag.api.VertexManagerPluginDescriptor;
+import org.apache.tez.runtime.library.cartesianproduct.CartesianProductConfig;
+import org.apache.tez.runtime.library.cartesianproduct.CartesianProductEdgeManager;
+import org.apache.tez.runtime.library.cartesianproduct.CartesianProductVertexManager;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -778,5 +787,68 @@ public class TestFaultTolerance {
     // dag will fail with 2 attempts failing from vertex v1
     runDAGAndVerify(dag, DAGStatus.State.FAILED, 2, "no progress");
   }
-  
+
+  /**
+   * In unpartitioned cartesian product, failure fraction should be #unique failure/#consumer that
+   * depends on the src task. Here we test a 2x2 cartesian product and let 4th destination task fail.
+   * The failure fraction limit is configured to be 0.25. So the failure fraction should be 1/2,
+   * not 1/4.
+   * @throws Exception
+   */
+  @Test
+  public void testCartesianProduct() throws Exception {
+    Configuration dagConf = new Configuration();
+    dagConf.setDouble(TezConfiguration.TEZ_TASK_MAX_ALLOWED_OUTPUT_FAILURES_FRACTION, 0.25);
+    DAG dag = DAG.create("dag");
+
+    Configuration vertexConf = new Configuration();
+    vertexConf.setInt(TestProcessor.getVertexConfName(
+      TestProcessor.TEZ_FAILING_PROCESSOR_VERIFY_TASK_INDEX, "v3"), 3);
+    vertexConf.setInt(TestProcessor.getVertexConfName(
+      TestProcessor.TEZ_FAILING_PROCESSOR_VERIFY_VALUE, "v3"), 5);
+    UserPayload vertexPayload = TezUtils.createUserPayloadFromConf(vertexConf);
+    ProcessorDescriptor processorDescriptor =
+      ProcessorDescriptor.create(TestProcessor.class.getName()).setUserPayload(vertexPayload);
+    Vertex v1 = Vertex.create("v1", processorDescriptor, 2);
+    Vertex v2 = Vertex.create("v2", processorDescriptor, 2);
+    Vertex v3 = Vertex.create("v3", processorDescriptor);
+
+    String[] sourceVertices = {"v1", "v2"};
+    CartesianProductConfig cartesianProductConfig =
+      new CartesianProductConfig(Arrays.asList(sourceVertices));
+    UserPayload cartesianProductPayload =
+      cartesianProductConfig.toUserPayload(new TezConfiguration());
+
+    v3.setVertexManagerPlugin(
+      VertexManagerPluginDescriptor.create(CartesianProductVertexManager.class.getName())
+        .setUserPayload(cartesianProductPayload));
+
+    EdgeManagerPluginDescriptor edgeManagerPluginDescriptor =
+      EdgeManagerPluginDescriptor.create(CartesianProductEdgeManager.class.getName())
+        .setUserPayload(cartesianProductPayload);
+
+    Configuration inputConf = new Configuration();
+    inputConf.setBoolean(TestInput.getVertexConfName(
+      TestInput.TEZ_FAILING_INPUT_DO_FAIL, "v3"), true);
+    inputConf.setInt(TestInput.getVertexConfName(
+      TestInput.TEZ_FAILING_INPUT_FAILING_TASK_INDEX, "v3"), 3);
+    inputConf.setInt(TestInput.getVertexConfName(
+      TestInput.TEZ_FAILING_INPUT_FAILING_TASK_ATTEMPT, "v3"), 0);
+    inputConf.setInt(TestInput.getVertexConfName(
+      TestInput.TEZ_FAILING_INPUT_FAILING_INPUT_INDEX, "v3"), 0);
+    inputConf.setInt(TestInput.getVertexConfName(
+      TestInput.TEZ_FAILING_INPUT_FAILING_UPTO_INPUT_ATTEMPT, "v3"), 0);
+    UserPayload inputPayload = TezUtils.createUserPayloadFromConf(inputConf);
+    EdgeProperty edgeProperty =
+      EdgeProperty.create(edgeManagerPluginDescriptor, DataMovementType.CUSTOM,
+        DataSourceType.PERSISTED, SchedulingType.SEQUENTIAL, TestOutput.getOutputDesc(null),
+        TestInput.getInputDesc(inputPayload));
+    Edge e1 = Edge.create(v1, v3, edgeProperty);
+    Edge e2 = Edge.create(v2, v3, edgeProperty);
+    dag.addVertex(v1).addVertex(v2).addVertex(v3);
+    dag.addEdge(e1).addEdge(e2);
+
+    // run dag
+    runDAGAndVerify(dag, DAGStatus.State.SUCCEEDED);
+  }
 }


Mime
View raw message