tez-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From bi...@apache.org
Subject [05/50] [abbrv] git commit: TEZ-1494. DAG hangs waiting for ShuffleManager.getNextInput() (Rajesh Balamohan)
Date Mon, 29 Sep 2014 00:35:11 GMT
TEZ-1494. DAG hangs waiting for ShuffleManager.getNextInput() (Rajesh Balamohan)


Project: http://git-wip-us.apache.org/repos/asf/tez/repo
Commit: http://git-wip-us.apache.org/repos/asf/tez/commit/fcc74261
Tree: http://git-wip-us.apache.org/repos/asf/tez/tree/fcc74261
Diff: http://git-wip-us.apache.org/repos/asf/tez/diff/fcc74261

Branch: refs/heads/branch-0.5
Commit: fcc74261a0d76d00580f4d0dca042a1ed014ccec
Parents: bf6ac4e
Author: Rajesh Balamohan <rbalamohan@apache.org>
Authored: Fri Sep 12 04:09:21 2014 +0530
Committer: Rajesh Balamohan <rbalamohan@apache.org>
Committed: Fri Sep 12 04:09:21 2014 +0530

----------------------------------------------------------------------
 .../dag/impl/ImmediateStartVertexManager.java   |  92 ++++-
 .../app/dag/impl/RootInputVertexManager.java    |  27 +-
 .../tez/dag/app/dag/impl/TestVertexImpl.java    | 364 +++++++++++++++++--
 3 files changed, 417 insertions(+), 66 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/tez/blob/fcc74261/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/ImmediateStartVertexManager.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/ImmediateStartVertexManager.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/ImmediateStartVertexManager.java
index b202d70..ac2b851 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/ImmediateStartVertexManager.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/ImmediateStartVertexManager.java
@@ -18,9 +18,11 @@
 
 package org.apache.tez.dag.app.dag.impl;
 
-import java.util.List;
-import java.util.Map;
-
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.tez.dag.api.EdgeProperty;
 import org.apache.tez.dag.api.InputDescriptor;
 import org.apache.tez.dag.api.VertexManagerPlugin;
 import org.apache.tez.dag.api.VertexManagerPluginContext;
@@ -28,29 +30,101 @@ import org.apache.tez.dag.api.VertexManagerPluginContext.TaskWithLocationHint;
 import org.apache.tez.runtime.api.Event;
 import org.apache.tez.runtime.api.events.VertexManagerEvent;
 
-import com.google.common.collect.Lists;
+import java.util.List;
+import java.util.Map;
 
 /**
  * Starts all tasks immediately on vertex start
  */
 public class ImmediateStartVertexManager extends VertexManagerPlugin {
 
+  private static final Log LOG = LogFactory.getLog(ImmediateStartVertexManager.class);
+
+  private final Map<String, SourceVertexInfo> srcVertexInfo = Maps.newHashMap();
+  private int managedTasks;
+  private boolean tasksScheduled = false;
+
+  class SourceVertexInfo {
+    EdgeProperty edgeProperty;
+    int numFinishedTasks;
+
+    SourceVertexInfo(EdgeProperty edgeProperty) {
+      this.edgeProperty = edgeProperty;
+    }
+  }
+
   public ImmediateStartVertexManager(VertexManagerPluginContext context) {
     super(context);
   }
 
   @Override
   public void onVertexStarted(Map<String, List<Integer>> completions) {
-    int numTasks = getContext().getVertexNumTasks(getContext().getVertexName());
-    List<TaskWithLocationHint> scheduledTasks = Lists.newArrayListWithCapacity(numTasks);
-    for (int i=0; i<numTasks; ++i) {
-      scheduledTasks.add(new TaskWithLocationHint(new Integer(i), null));
+    managedTasks = getContext().getVertexNumTasks(getContext().getVertexName());
+    Map<String, EdgeProperty> edges = getContext().getInputVertexEdgeProperties();
+    for (Map.Entry<String, EdgeProperty> entry : edges.entrySet()) {
+      String srcVertex = entry.getKey();
+      EdgeProperty edgeProp = entry.getValue();
+      srcVertexInfo.put(srcVertex, new SourceVertexInfo(edgeProp));
+    }
+
+    //handle completions
+    for (Map.Entry<String, List<Integer>> entry : completions.entrySet()) {
+      for (Integer task : entry.getValue()) {
+        handleSourceTaskFinished(entry.getKey(), task);
+      }
+    }
+    scheduleTasks();
+  }
+
+  private void handleSourceTaskFinished(String vertex, Integer taskId) {
+    SourceVertexInfo srcInfo = srcVertexInfo.get(vertex);
+    //Not mandatory to check for duplicate completions here
+    srcInfo.numFinishedTasks++;
+  }
+
+  private void scheduleTasks() {
+    if (!canScheduleTasks()) {
+      return;
+    }
+
+    List<TaskWithLocationHint> tasksToStart = Lists.newArrayListWithCapacity(managedTasks);
+    for (int i = 0; i < managedTasks; ++i) {
+      tasksToStart.add(new TaskWithLocationHint(new Integer(i), null));
+    }
+
+    if (!tasksToStart.isEmpty()) {
+      LOG.info("Starting " + tasksToStart.size() + " in " + getContext().getVertexName());
+      getContext().scheduleVertexTasks(tasksToStart);
     }
-    getContext().scheduleVertexTasks(scheduledTasks);
+    tasksScheduled = true;
+  }
+
+  private boolean canScheduleTasks() {
+    //Check if at least 1 task is finished from each source vertex (in case of broadcast &
+    // one-to-one or custom)
+    for (Map.Entry<String, SourceVertexInfo> entry : srcVertexInfo.entrySet()) {
+      SourceVertexInfo srcVertexInfo = entry.getValue();
+      switch(srcVertexInfo.edgeProperty.getDataMovementType()) {
+      case ONE_TO_ONE:
+      case BROADCAST:
+      case CUSTOM:
+        if (srcVertexInfo.numFinishedTasks == 0) {
+          //do not schedule tasks until a task from source task is complete
+          return false;
+        }
+      default:
+        break;
+      }
+    }
+    return true;
   }
 
   @Override
   public void onSourceTaskCompleted(String srcVertexName, Integer attemptId) {
+    handleSourceTaskFinished(srcVertexName, attemptId);
+    if (!tasksScheduled) {
+      scheduleTasks();
+    }
   }
 
   @Override

http://git-wip-us.apache.org/repos/asf/tez/blob/fcc74261/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/RootInputVertexManager.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/RootInputVertexManager.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/RootInputVertexManager.java
index e6ffdc5..e850286 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/RootInputVertexManager.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/RootInputVertexManager.java
@@ -18,27 +18,23 @@
 
 package org.apache.tez.dag.app.dag.impl;
 
-import java.nio.ByteBuffer;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 
 import org.apache.tez.dag.api.InputDescriptor;
 import org.apache.tez.dag.api.UserPayload;
-import org.apache.tez.dag.api.VertexManagerPlugin;
 import org.apache.tez.dag.api.VertexManagerPluginContext;
-import org.apache.tez.dag.api.VertexManagerPluginContext.TaskWithLocationHint;
 import org.apache.tez.runtime.api.Event;
 import org.apache.tez.runtime.api.InputSpecUpdate;
 import org.apache.tez.runtime.api.events.InputConfigureVertexTasksEvent;
 import org.apache.tez.runtime.api.events.InputDataInformationEvent;
 import org.apache.tez.runtime.api.events.InputUpdatePayloadEvent;
-import org.apache.tez.runtime.api.events.VertexManagerEvent;
 
 import com.google.common.base.Preconditions;
 import com.google.common.collect.Lists;
 
-public class RootInputVertexManager extends VertexManagerPlugin {
+public class RootInputVertexManager extends ImmediateStartVertexManager {
 
   private String configuredInputName;
 
@@ -46,27 +42,6 @@ public class RootInputVertexManager extends VertexManagerPlugin {
     super(context);
   }
 
-  @Override
-  public void initialize() {
-  }
-
-  @Override
-  public void onVertexStarted(Map<String, List<Integer>> completions) {
-    int numTasks = getContext().getVertexNumTasks(getContext().getVertexName());
-    List<TaskWithLocationHint> scheduledTasks = Lists.newArrayListWithCapacity(numTasks);
-    for (int i=0; i<numTasks; ++i) {
-      scheduledTasks.add(new TaskWithLocationHint(new Integer(i), null));
-    }
-    getContext().scheduleVertexTasks(scheduledTasks);
-  }
-
-  @Override
-  public void onSourceTaskCompleted(String srcVertexName, Integer attemptId) {
-  }
-
-  @Override
-  public void onVertexManagerEventReceived(VertexManagerEvent vmEvent) {
-  }
 
   @Override
   public void onRootVertexInitialized(String inputName, InputDescriptor inputDescriptor,

http://git-wip-us.apache.org/repos/asf/tez/blob/fcc74261/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexImpl.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexImpl.java b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexImpl.java
index d894928..04e2219 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexImpl.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexImpl.java
@@ -20,6 +20,7 @@ package org.apache.tez.dag.app.dag.impl;
 
 import java.nio.ByteBuffer;
 
+import static org.junit.Assert.assertTrue;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyInt;
 import static org.mockito.Mockito.doReturn;
@@ -1879,39 +1880,39 @@ public class TestVertexImpl {
     Assert.assertEquals(2, v3.getOutputVerticesCount());
     Assert.assertEquals(2, v3.getOutputVerticesCount());
 
-    Assert.assertTrue("vertex1".equals(v3.getInputSpecList(0).get(0)
+    assertTrue("vertex1".equals(v3.getInputSpecList(0).get(0)
         .getSourceVertexName())
         || "vertex2".equals(v3.getInputSpecList(0).get(0)
-            .getSourceVertexName()));
-    Assert.assertTrue("vertex1".equals(v3.getInputSpecList(0).get(1)
+        .getSourceVertexName()));
+    assertTrue("vertex1".equals(v3.getInputSpecList(0).get(1)
         .getSourceVertexName())
         || "vertex2".equals(v3.getInputSpecList(0).get(1)
-            .getSourceVertexName()));
-    Assert.assertTrue("i3_v1".equals(v3.getInputSpecList(0).get(0)
+        .getSourceVertexName()));
+    assertTrue("i3_v1".equals(v3.getInputSpecList(0).get(0)
         .getInputDescriptor().getClassName())
         || "i3_v2".equals(v3.getInputSpecList(0).get(0)
-            .getInputDescriptor().getClassName()));
-    Assert.assertTrue("i3_v1".equals(v3.getInputSpecList(0).get(1)
+        .getInputDescriptor().getClassName()));
+    assertTrue("i3_v1".equals(v3.getInputSpecList(0).get(1)
         .getInputDescriptor().getClassName())
         || "i3_v2".equals(v3.getInputSpecList(0).get(1)
-            .getInputDescriptor().getClassName()));
+        .getInputDescriptor().getClassName()));
 
-    Assert.assertTrue("vertex4".equals(v3.getOutputSpecList(0).get(0)
+    assertTrue("vertex4".equals(v3.getOutputSpecList(0).get(0)
         .getDestinationVertexName())
         || "vertex5".equals(v3.getOutputSpecList(0).get(0)
-            .getDestinationVertexName()));
-    Assert.assertTrue("vertex4".equals(v3.getOutputSpecList(0).get(1)
+        .getDestinationVertexName()));
+    assertTrue("vertex4".equals(v3.getOutputSpecList(0).get(1)
         .getDestinationVertexName())
         || "vertex5".equals(v3.getOutputSpecList(0).get(1)
-            .getDestinationVertexName()));
-    Assert.assertTrue("o3_v4".equals(v3.getOutputSpecList(0).get(0)
+        .getDestinationVertexName()));
+    assertTrue("o3_v4".equals(v3.getOutputSpecList(0).get(0)
         .getOutputDescriptor().getClassName())
         || "o3_v5".equals(v3.getOutputSpecList(0).get(0)
-            .getOutputDescriptor().getClassName()));
-    Assert.assertTrue("o3_v4".equals(v3.getOutputSpecList(0).get(1)
+        .getOutputDescriptor().getClassName()));
+    assertTrue("o3_v4".equals(v3.getOutputSpecList(0).get(1)
         .getOutputDescriptor().getClassName())
         || "o3_v5".equals(v3.getOutputSpecList(0).get(1)
-            .getOutputDescriptor().getClassName()));
+        .getOutputDescriptor().getClassName()));
   }
 
   @Test(timeout = 5000)
@@ -1940,13 +1941,13 @@ public class TestVertexImpl {
     Map<String, EdgeManagerPluginDescriptor> edgeManagerDescriptors =
         Collections.singletonMap(
        v1.getName(), mockEdgeManagerDescriptor);
-    Assert.assertTrue(v3.setParallelism(1, null, edgeManagerDescriptors, null));
-    Assert.assertTrue(v3.sourceVertices.get(v1).getEdgeManager() instanceof
+    assertTrue(v3.setParallelism(1, null, edgeManagerDescriptors, null));
+    assertTrue(v3.sourceVertices.get(v1).getEdgeManager() instanceof
         EdgeManagerForTest);
     Assert.assertEquals(1, v3.getTotalTasks());
     Assert.assertEquals(1, tasks.size());
     // the last one is removed
-    Assert.assertTrue(tasks.keySet().iterator().next().equals(firstTask));
+    assertTrue(tasks.keySet().iterator().next().equals(firstTask));
 
   }
   
@@ -1993,7 +1994,7 @@ public class TestVertexImpl {
     Edge edge = edges.get("e4");
     EdgeManagerPlugin em = edge.getEdgeManager();
     EdgeManagerForTest originalEm = (EdgeManagerForTest) em;
-    Assert.assertTrue(Arrays.equals(edgePayload, originalEm.getEdgeManagerContext()
+    assertTrue(Arrays.equals(edgePayload, originalEm.getEdgeManagerContext()
         .getUserPayload().deepCopyAsArray()));
 
     UserPayload userPayload = UserPayload.create(ByteBuffer.wrap(new String("foo").getBytes()));
@@ -2007,7 +2008,7 @@ public class TestVertexImpl {
 
     Map<String, EdgeManagerPluginDescriptor> edgeManagerDescriptors =
         Collections.singletonMap(v3.getName(), edgeManagerDescriptor);
-    Assert.assertTrue(v5.setParallelism(v5.getTotalTasks() - 1, null,
+    assertTrue(v5.setParallelism(v5.getTotalTasks() - 1, null,
         edgeManagerDescriptors, null)); // Must decrease.
 
     VertexImpl v5Impl = (VertexImpl) v5;
@@ -2015,10 +2016,10 @@ public class TestVertexImpl {
     EdgeManagerPlugin modifiedEdgeManager = v5Impl.sourceVertices.get(v3)
         .getEdgeManager();
     Assert.assertNotNull(modifiedEdgeManager);
-    Assert.assertTrue(modifiedEdgeManager instanceof EdgeManagerForTest);
+    assertTrue(modifiedEdgeManager instanceof EdgeManagerForTest);
 
     // Ensure initialize() is called with the correct payload
-    Assert.assertTrue(Arrays.equals(userPayload.deepCopyAsArray(),
+    assertTrue(Arrays.equals(userPayload.deepCopyAsArray(),
         ((EdgeManagerForTest) modifiedEdgeManager).getUserPayload().deepCopyAsArray()));
   }
 
@@ -2092,7 +2093,7 @@ public class TestVertexImpl {
     Assert.assertEquals(VertexTerminationCause.OWN_TASK_FAILURE, v.getTerminationCause());
     String diagnostics =
         StringUtils.join(v.getDiagnostics(), ",").toLowerCase();
-    Assert.assertTrue(diagnostics.contains("task failed"
+    assertTrue(diagnostics.contains("task failed"
         + ", taskid=" + t1.toString()));
   }
 
@@ -2104,7 +2105,7 @@ public class TestVertexImpl {
     String diagnostics =
         StringUtils.join(v2.getDiagnostics(), ",").toLowerCase();
     LOG.info("diagnostics v2: " + diagnostics);
-    Assert.assertTrue(diagnostics.contains(
+    assertTrue(diagnostics.contains(
         "vertex received kill in inited state"));
   }
 
@@ -2118,7 +2119,7 @@ public class TestVertexImpl {
     killVertex(v3);
     String diagnostics =
         StringUtils.join(v3.getDiagnostics(), ",").toLowerCase();
-    Assert.assertTrue(diagnostics.contains(
+    assertTrue(diagnostics.contains(
         "vertex received kill while in running state"));
   }
 
@@ -2200,7 +2201,7 @@ public class TestVertexImpl {
     Assert.assertNull(v2.getOutputCommitter("output"));
 
     VertexImpl v6 = vertices.get("vertex6");
-    Assert.assertTrue(v6.getOutputCommitter("outputx")
+    assertTrue(v6.getOutputCommitter("outputx")
         instanceof CountingOutputCommitter);
   }
 
@@ -2208,11 +2209,11 @@ public class TestVertexImpl {
   public void testVertexManagerInit() {
     initAllVertices(VertexState.INITED);
     VertexImpl v2 = vertices.get("vertex2");
-    Assert.assertTrue(v2.getVertexManager().getPlugin()
+    assertTrue(v2.getVertexManager().getPlugin()
         instanceof ImmediateStartVertexManager);
 
     VertexImpl v6 = vertices.get("vertex6");
-    Assert.assertTrue(v6.getVertexManager().getPlugin()
+    assertTrue(v6.getVertexManager().getPlugin()
         instanceof ShuffleVertexManager);
   }
 
@@ -3020,6 +3021,306 @@ public class TestVertexImpl {
     }
   }
 
+  @Test(timeout = 5000)
+  /**
+   * Ref: TEZ-1494
+   * If broadcast, one-to-one or custom edges are present in source, tasks should not start until
+   * 1 task from each source vertex is complete.
+   */
+  public void testTaskSchedulingWithCustomEdges() {
+    setupPreDagCreation();
+    dagPlan = createCustomDAGWithCustomEdges();
+    setupPostDagCreation();
+
+    /**
+     *
+     *   M2 --(SG)--> R3 --(B)--\
+     *                           \
+     *   M7 --(B)---------------->M5 ---(SG)--> R6
+     *                            /
+     *   M8---(C)--------------->/
+     */
+
+    //init M2
+    VertexImpl m2 = vertices.get("M2");
+    VertexImpl m7 = vertices.get("M7");
+    VertexImpl r3 = vertices.get("R3");
+    VertexImpl m5 = vertices.get("M5");
+    VertexImpl m8 = vertices.get("M8");
+
+    initVertex(m2);
+    initVertex(m7);
+    initVertex(m8);
+    assertTrue(m7.getState().equals(VertexState.INITED));
+    assertTrue(m5.getState().equals(VertexState.INITED));
+    assertTrue(m8.getState().equals(VertexState.INITED));
+    assertTrue(m7.getVertexManager().getPlugin() instanceof ImmediateStartVertexManager);
+
+    //Start M2; Let tasks complete in M2; Also let 1 task complete in R3
+    dispatcher.getEventHandler().handle(new VertexEvent(m2.getVertexId(), VertexEventType.V_START));
+    dispatcher.await();
+    VertexEventTaskAttemptCompleted taskAttemptCompleted = new VertexEventTaskAttemptCompleted
+        (TezTaskAttemptID.getInstance(TezTaskID.getInstance(m2.getVertexId(),0), 0), TaskAttemptStateInternal.SUCCEEDED);
+    VertexEventTaskCompleted taskCompleted = new VertexEventTaskCompleted(TezTaskID.getInstance(m2
+        .getVertexId(), 0), TaskState.SUCCEEDED);
+    dispatcher.getEventHandler().handle(taskAttemptCompleted);
+    dispatcher.getEventHandler().handle(taskCompleted);
+    dispatcher.await();
+    taskAttemptCompleted = new VertexEventTaskAttemptCompleted
+        (TezTaskAttemptID.getInstance(TezTaskID.getInstance(r3.getVertexId(),0), 0),
+            TaskAttemptStateInternal.SUCCEEDED);
+    taskCompleted = new VertexEventTaskCompleted(TezTaskID.getInstance(r3
+        .getVertexId(), 0), TaskState.SUCCEEDED);
+    dispatcher.getEventHandler().handle(taskAttemptCompleted);
+    dispatcher.getEventHandler().handle(taskCompleted);
+    dispatcher.await();
+    assertTrue(m2.getState().equals(VertexState.SUCCEEDED));
+    assertTrue(m5.numSuccessSourceAttemptCompletions == 1);
+    assertTrue(m5.getState().equals(VertexState.INITED));
+
+    //R3 should be in running state as it has one task completed, and rest are pending
+    assertTrue(r3.getState().equals(VertexState.RUNNING));
+
+    //Let us start M7; M5 should start not start as it is dependent on M8 as well
+    dispatcher.getEventHandler().handle(new VertexEvent(m7.getVertexId(),VertexEventType.V_START));
+    dispatcher.await();
+    //Let one of the tasks get over in M7 as well.
+    taskAttemptCompleted = new VertexEventTaskAttemptCompleted
+        (TezTaskAttemptID.getInstance(TezTaskID.getInstance(m7.getVertexId(),0), 0),
+            TaskAttemptStateInternal.SUCCEEDED);
+    taskCompleted = new VertexEventTaskCompleted(TezTaskID.getInstance(m7
+        .getVertexId(), 0), TaskState.SUCCEEDED);
+    dispatcher.getEventHandler().handle(taskAttemptCompleted);
+    dispatcher.getEventHandler().handle(taskCompleted);
+    dispatcher.await();
+    assertTrue(m5.numSuccessSourceAttemptCompletions == 2);
+
+    //M5 should be in INITED state, as it depends on M8
+    assertTrue(m5.getState().equals(VertexState.INITED));
+    for(Task task : m5.getTasks().values()) {
+      assertTrue(task.getState().equals(TaskState.NEW));
+    }
+
+    //Let us start M8; M5 should start now
+    dispatcher.getEventHandler().handle(new VertexEvent(m8.getVertexId(),VertexEventType.V_START));
+    dispatcher.await();
+
+    //M5 in running state. But tasks should not be scheduled until M8 finishes a task.
+    assertTrue(m5.getState().equals(VertexState.RUNNING));
+    for(Task task : m5.getTasks().values()) {
+      assertTrue(task.getState().equals(TaskState.NEW));
+    }
+
+    //Let one of the tasks get over in M8 as well. This should trigger tasks to be scheduled in M5
+    taskAttemptCompleted = new VertexEventTaskAttemptCompleted
+        (TezTaskAttemptID.getInstance(TezTaskID.getInstance(m8.getVertexId(),0), 0),
+            TaskAttemptStateInternal.SUCCEEDED);
+    taskCompleted = new VertexEventTaskCompleted(TezTaskID.getInstance(m8
+        .getVertexId(), 0), TaskState.SUCCEEDED);
+    dispatcher.getEventHandler().handle(taskAttemptCompleted);
+    dispatcher.getEventHandler().handle(taskCompleted);
+    dispatcher.await();
+
+    assertTrue(m5.numSuccessSourceAttemptCompletions == 3);
+    //Ensure all tasks in M5 are in scheduled state
+    for(Task task : m5.getTasks().values()) {
+      assertTrue(task.getState().equals(TaskState.SCHEDULED));
+    }
+  }
+
+  //For TEZ-1494
+  private DAGPlan createCustomDAGWithCustomEdges() {
+    /**
+     *
+     *   M2 --(SG)--> R3 --(B)--\
+     *                           \
+     *   M7 --(B)---------------->M5 ---(SG)--> R6
+     *                            /
+     *   M8---(C)--------------->/
+     */
+    DAGPlan dag = DAGPlan.newBuilder().setName("TestSamplerDAG")
+        .addVertex(VertexPlan.newBuilder()
+                .setName("M2")
+                .setProcessorDescriptor(
+                    TezEntityDescriptorProto.newBuilder().setClassName("M2.class"))
+                .setType(PlanVertexType.NORMAL)
+                .addTaskLocationHint(
+                    PlanTaskLocationHint.newBuilder().addHost("host1").addRack("rack1").build())
+                .setTaskConfig(PlanTaskConfiguration.newBuilder()
+                        .setNumTasks(1)
+                        .setVirtualCores(4)
+                        .setMemoryMb(1024)
+                        .setJavaOpts("")
+                        .setTaskModule("M2.class")
+                        .build()
+                )
+                .addOutEdgeId("M2_R3")
+                .build()
+        )
+        .addVertex(VertexPlan.newBuilder()
+                .setName("M8")
+                .setProcessorDescriptor(
+                    TezEntityDescriptorProto.newBuilder().setClassName("M8.class"))
+                .setType(PlanVertexType.NORMAL)
+                .addTaskLocationHint(
+                    PlanTaskLocationHint.newBuilder().addHost("host1").addRack("rack1").build())
+                .setTaskConfig(PlanTaskConfiguration.newBuilder()
+                        .setNumTasks(1)
+                        .setVirtualCores(4)
+                        .setMemoryMb(1024)
+                        .setJavaOpts("")
+                        .setTaskModule("M8.class")
+                        .build()
+                )
+                .addOutEdgeId("M8_M5")
+                .build()
+        )
+         .addVertex(VertexPlan.newBuilder()
+                 .setName("R3")
+                 .setProcessorDescriptor(
+                     TezEntityDescriptorProto.newBuilder().setClassName("M2.class"))
+                 .setType(PlanVertexType.NORMAL)
+                 .addTaskLocationHint(
+                     PlanTaskLocationHint.newBuilder().addHost("host2").addRack("rack1").build())
+                 .setTaskConfig(PlanTaskConfiguration.newBuilder()
+                         .setNumTasks(10)
+                         .setVirtualCores(4)
+                         .setMemoryMb(1024)
+                         .setJavaOpts("")
+                         .setTaskModule("R3.class")
+                         .build()
+                 )
+                 .addInEdgeId("M2_R3")
+                 .addOutEdgeId("R3_M5")
+                 .build()
+         )
+        .addVertex(VertexPlan.newBuilder()
+                .setName("M5")
+                .setProcessorDescriptor(
+                    TezEntityDescriptorProto.newBuilder().setClassName("M5.class"))
+                .setType(PlanVertexType.NORMAL)
+                .addTaskLocationHint(
+                    PlanTaskLocationHint.newBuilder().addHost("host3").addRack("rack1").build())
+                .setTaskConfig(PlanTaskConfiguration.newBuilder()
+                        .setNumTasks(10)
+                        .setVirtualCores(4)
+                        .setMemoryMb(1024)
+                        .setJavaOpts("")
+                        .setTaskModule("M5.class")
+                        .build()
+                )
+                .addInEdgeId("R3_M5")
+                .addInEdgeId("M7_M5")
+                .addInEdgeId("M8_M5")
+                .addOutEdgeId("M5_R6")
+                .build()
+        )
+        .addVertex(VertexPlan.newBuilder()
+                .setName("M7")
+                .setProcessorDescriptor(
+                    TezEntityDescriptorProto.newBuilder().setClassName("M7.class"))
+                .setType(PlanVertexType.NORMAL)
+                .addTaskLocationHint(
+                    PlanTaskLocationHint.newBuilder().addHost("host4").addRack("rack1").build())
+                .setTaskConfig(PlanTaskConfiguration.newBuilder()
+                        .setNumTasks(10)
+                        .setVirtualCores(4)
+                        .setMemoryMb(1024)
+                        .setJavaOpts("")
+                        .setTaskModule("M7.class")
+                        .build()
+                )
+                .addOutEdgeId("M7_M5")
+                .build()
+        )
+        .addVertex(VertexPlan.newBuilder()
+                .setName("R6")
+                .setProcessorDescriptor(
+                    TezEntityDescriptorProto.newBuilder().setClassName("R6.class"))
+                .setType(PlanVertexType.NORMAL)
+                .addTaskLocationHint(
+                    PlanTaskLocationHint.newBuilder().addHost("host3").addRack("rack1").build())
+                .setTaskConfig(PlanTaskConfiguration.newBuilder()
+                        .setNumTasks(1)
+                        .setVirtualCores(4)
+                        .setMemoryMb(1024)
+                        .setJavaOpts("")
+                        .setTaskModule("R6.class")
+                        .build()
+                )
+                .addInEdgeId("M5_R6")
+                .build()
+        )
+        .addEdge(
+            EdgePlan.newBuilder()
+                .setEdgeDestination(TezEntityDescriptorProto.newBuilder().setClassName("M2_R3"))
+                .setInputVertexName("M2")
+                .setEdgeSource(TezEntityDescriptorProto.newBuilder().setClassName("M2_R3.class"))
+                .setOutputVertexName("R3")
+                .setDataMovementType(PlanEdgeDataMovementType.SCATTER_GATHER)
+                .setId("M2_R3")
+                .setDataSourceType(PlanEdgeDataSourceType.PERSISTED)
+                .setSchedulingType(PlanEdgeSchedulingType.SEQUENTIAL)
+                .build()
+        )
+        .addEdge(
+            EdgePlan.newBuilder()
+                .setEdgeDestination(TezEntityDescriptorProto.newBuilder().setClassName("R3_M5"))
+                .setInputVertexName("R3")
+                .setEdgeSource(TezEntityDescriptorProto.newBuilder().setClassName("R3_M5.class"))
+                .setOutputVertexName("M5")
+                .setDataMovementType(PlanEdgeDataMovementType.BROADCAST)
+                .setId("R3_M5")
+                .setDataSourceType(PlanEdgeDataSourceType.PERSISTED)
+                .setSchedulingType(PlanEdgeSchedulingType.SEQUENTIAL)
+                .build()
+        )
+        .addEdge(
+            EdgePlan.newBuilder()
+                .setEdgeDestination(TezEntityDescriptorProto.newBuilder().setClassName("M7_M5"))
+                .setInputVertexName("M7")
+                .setEdgeSource(TezEntityDescriptorProto.newBuilder().setClassName("M7_M5.class"))
+                .setOutputVertexName("M5")
+                .setDataMovementType(PlanEdgeDataMovementType.BROADCAST)
+                .setId("M7_M5")
+                .setDataSourceType(PlanEdgeDataSourceType.PERSISTED)
+                .setSchedulingType(PlanEdgeSchedulingType.SEQUENTIAL)
+                .build()
+        )
+        .addEdge(
+            EdgePlan.newBuilder()
+                .setEdgeDestination(TezEntityDescriptorProto.newBuilder().setClassName("M5_R6"))
+                .setInputVertexName("M5")
+                .setEdgeSource(TezEntityDescriptorProto.newBuilder().setClassName("M5_R6.class"))
+                .setOutputVertexName("R6")
+                .setDataMovementType(PlanEdgeDataMovementType.SCATTER_GATHER)
+                .setId("M5_R6")
+                .setDataSourceType(PlanEdgeDataSourceType.PERSISTED)
+                .setSchedulingType(PlanEdgeSchedulingType.SEQUENTIAL)
+                .build()
+        )
+        .addEdge(
+            EdgePlan.newBuilder()
+                .setEdgeDestination(TezEntityDescriptorProto.newBuilder().setClassName("M8_M5"))
+                .setInputVertexName("M8")
+                .setEdgeSource(TezEntityDescriptorProto.newBuilder().setClassName("M8_M5.class"))
+                .setEdgeManager(
+                    TezEntityDescriptorProto.newBuilder()
+                        .setClassName(EdgeManagerForTest.class.getName())
+                        .setUserPayload(ByteString.copyFrom(edgePayload))
+                        .build())
+                .setOutputVertexName("M5")
+                .setDataMovementType(PlanEdgeDataMovementType.CUSTOM)
+                .setId("M8_M5")
+                .setDataSourceType(PlanEdgeDataSourceType.PERSISTED)
+                .setSchedulingType(PlanEdgeSchedulingType.SEQUENTIAL)
+                .build()
+        )
+        .build();
+
+    return dag;
+  }
+
   @SuppressWarnings("unchecked")
   @Test(timeout = 5000)
   public void testVertexWithInitializerSuccess() {
@@ -3215,6 +3516,7 @@ public class TestVertexImpl {
           dispatcher.getEventHandler(), taskAttemptListener,
           clock, thh, true, appContext, vertexLocationHint, null, taskSpecificLaunchCmdOption,
           updateTracker);
+      v.setInputVertices(new HashMap());
       vertexIdMap.put(vId, v);
       vertices.put(v.getName(), v);
       v.handle(new VertexEvent(vId, VertexEventType.V_INIT));
@@ -3454,8 +3756,8 @@ public class TestVertexImpl {
     List<GroupInputSpec> groupInSpec = vC.getGroupInputSpecList(0);
     Assert.assertEquals(1, groupInSpec.size());
     Assert.assertEquals("Group", groupInSpec.get(0).getGroupName());
-    Assert.assertTrue(groupInSpec.get(0).getGroupVertices().contains("A"));
-    Assert.assertTrue(groupInSpec.get(0).getGroupVertices().contains("B"));
+    assertTrue(groupInSpec.get(0).getGroupVertices().contains("A"));
+    assertTrue(groupInSpec.get(0).getGroupVertices().contains("B"));
     groupInSpec.get(0).getMergedInputDescriptor().getClassName().equals("Group.class");
   }
   


Mime
View raw message