tez-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From ss...@apache.org
Subject git commit: TEZ-1169. Allow numPhysicalInputs to be specified for RootInputs. (sseth)
Date Tue, 10 Jun 2014 20:08:58 GMT
Repository: incubator-tez
Updated Branches:
  refs/heads/master d77a2255a -> d3fdd81bc


TEZ-1169. Allow numPhysicalInputs to be specified for RootInputs.
(sseth)


Project: http://git-wip-us.apache.org/repos/asf/incubator-tez/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-tez/commit/d3fdd81b
Tree: http://git-wip-us.apache.org/repos/asf/incubator-tez/tree/d3fdd81b
Diff: http://git-wip-us.apache.org/repos/asf/incubator-tez/diff/d3fdd81b

Branch: refs/heads/master
Commit: d3fdd81bce50e9e4ac3b0b9f7e773d79eb9ac2a1
Parents: d77a225
Author: Siddharth Seth <sseth@apache.org>
Authored: Tue Jun 10 13:08:30 2014 -0700
Committer: Siddharth Seth <sseth@apache.org>
Committed: Tue Jun 10 13:08:30 2014 -0700

----------------------------------------------------------------------
 .../tez/dag/api/VertexManagerPluginContext.java |  13 +-
 .../tez/runtime/api/RootInputSpecUpdate.java    | 101 ++++++++++
 .../RootInputConfigureVertexTasksEvent.java     |  13 +-
 .../java/org/apache/tez/dag/app/dag/Vertex.java |   5 +-
 .../app/dag/impl/RootInputVertexManager.java    |   9 +-
 .../apache/tez/dag/app/dag/impl/VertexImpl.java | 119 +++++++----
 .../tez/dag/app/dag/impl/VertexManager.java     |   9 +-
 .../events/VertexParallelismUpdatedEvent.java   |  42 +++-
 tez-dag/src/main/proto/HistoryEvents.proto      |   7 +
 .../tez/dag/app/dag/impl/TestVertexImpl.java    | 197 ++++++++++++++++++-
 .../TestHistoryEventsProtoConversion.java       |  27 ++-
 .../common/MRInputAMSplitGenerator.java         |   4 +-
 .../vertexmanager/ShuffleVertexManager.java     |   2 +-
 .../vertexmanager/TestShuffleVertexManager.java |   8 +-
 14 files changed, 485 insertions(+), 71 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/d3fdd81b/tez-api/src/main/java/org/apache/tez/dag/api/VertexManagerPluginContext.java
----------------------------------------------------------------------
diff --git a/tez-api/src/main/java/org/apache/tez/dag/api/VertexManagerPluginContext.java b/tez-api/src/main/java/org/apache/tez/dag/api/VertexManagerPluginContext.java
index 70cb6d2..7c48adc 100644
--- a/tez-api/src/main/java/org/apache/tez/dag/api/VertexManagerPluginContext.java
+++ b/tez-api/src/main/java/org/apache/tez/dag/api/VertexManagerPluginContext.java
@@ -22,13 +22,15 @@ import java.util.Collection;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+
 import javax.annotation.Nullable;
 
 import org.apache.hadoop.classification.InterfaceStability.Unstable;
 import org.apache.hadoop.yarn.api.records.Container;
-import javax.annotation.Nullable;
+
 import org.apache.hadoop.yarn.api.records.Resource;
 import org.apache.tez.dag.api.VertexLocationHint.TaskLocationHint;
+import org.apache.tez.runtime.api.RootInputSpecUpdate;
 import org.apache.tez.runtime.api.events.RootInputDataInformationEvent;
 
 import com.google.common.base.Preconditions;
@@ -111,19 +113,22 @@ public interface VertexManagerPluginContext {
   int getNumClusterNodes();
   
   /**
-   * Set the new parallelism (number of tasks) of this vertex.
+   * Set the new parallelism (number of tasks) of this vertex.
    * Map of source (input) vertices and edge managers to change the event routing
-   * between the source tasks and the new destination tasks.
+   * between the source tasks and the new destination tasks, and the number of physical inputs for each root input.
    * This API can change the parallelism only once. Subsequent attempts will be 
    * disallowed
    * @param parallelism New number of tasks in the vertex
    * @param locationHint the placement policy for tasks.
    * @param sourceEdgeManagers Edge Managers to be updated
+   * @param rootInputSpecUpdate Updated Root Input specifications, if any.
+   *        If none specified, a default of 1 physical input is used
    * @return true if the operation was allowed.
    */
   public boolean setVertexParallelism(int parallelism,
       @Nullable VertexLocationHint locationHint,
-      @Nullable Map<String, EdgeManagerDescriptor> sourceEdgeManagers);
+      @Nullable Map<String, EdgeManagerDescriptor> sourceEdgeManagers,
+      @Nullable Map<String, RootInputSpecUpdate> rootInputSpecUpdate);
   
   /**
    * Allows a VertexManagerPlugin to assign Events for Root Inputs

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/d3fdd81b/tez-api/src/main/java/org/apache/tez/runtime/api/RootInputSpecUpdate.java
----------------------------------------------------------------------
diff --git a/tez-api/src/main/java/org/apache/tez/runtime/api/RootInputSpecUpdate.java b/tez-api/src/main/java/org/apache/tez/runtime/api/RootInputSpecUpdate.java
new file mode 100644
index 0000000..72adf78
--- /dev/null
+++ b/tez-api/src/main/java/org/apache/tez/runtime/api/RootInputSpecUpdate.java
@@ -0,0 +1,101 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.api;
+
+import java.util.List;
+
+import org.apache.hadoop.classification.InterfaceAudience.Private;
+import org.apache.tez.runtime.api.events.RootInputDataInformationEvent;
+
+import com.google.common.collect.Lists;
+
+/**
+ * Update Input specs for Root Inputs running in a task. Allows setting the number of physical
+ * inputs for all work units if they have the same number of physical inputs, or individual
+ * numPhysicalInputs for each work unit.
+ * 
+ */
+public class RootInputSpecUpdate {
+
+  private final boolean forAllWorkUnits;
+  private final List<Integer> numPhysicalInputs;
+
+  private final static RootInputSpecUpdate DEFAULT_SINGLE_PHYSICAL_INPUT_SPEC = createAllTaskRootInputSpecUpdate(1);
+  
+  /**
+   * Create an update instance where all work units (typically represented by
+   * {@link RootInputDataInformationEvent}) will have the same number of physical inputs.
+   * 
+   * @param numPhysicalInputs
+   *          the number of physical inputs for all work units which will use the LogicalInput
+   * @return an update applying {@code numPhysicalInputs} to every work unit
+   */
+  public static RootInputSpecUpdate createAllTaskRootInputSpecUpdate(int numPhysicalInputs) {
+    return new RootInputSpecUpdate(numPhysicalInputs);
+  }
+
+  /**
+   * Create an update instance where each work unit (typically represented by
+   * {@link RootInputDataInformationEvent}) has its own number of physical inputs.
+   * 
+   * @param perWorkUnitNumPhysicalInputs
+   *          A list containing one entry per work unit. The order in the list corresponds to task
+   *          index or equivalently the order of RootInputDataInformationEvents being sent.
+   * @return an update holding one physical-input count per work unit
+   */
+  public static RootInputSpecUpdate createPerTaskRootInputSpecUpdate(
+      List<Integer> perWorkUnitNumPhysicalInputs) {
+    return new RootInputSpecUpdate(perWorkUnitNumPhysicalInputs);
+  }
+  
+  public static RootInputSpecUpdate getDefaultSinglePhysicalInputSpecUpdate() {
+    return DEFAULT_SINGLE_PHYSICAL_INPUT_SPEC;
+  }
+
+  private RootInputSpecUpdate(int numPhysicalInputs) {
+    this.forAllWorkUnits = true;
+    this.numPhysicalInputs = Lists.newArrayList(numPhysicalInputs);
+  }
+
+  private RootInputSpecUpdate(List<Integer> perWorkUnitNumPhysicalInputs) {
+    this.forAllWorkUnits = false;
+    this.numPhysicalInputs = Lists.newArrayList(perWorkUnitNumPhysicalInputs);
+  }
+
+  @Private
+  public int getNumPhysicalInputsForWorkUnit(int index) {
+    if (this.forAllWorkUnits) {
+      return numPhysicalInputs.get(0);
+    } else {
+      return numPhysicalInputs.get(index);
+    }
+  }
+  
+  @Private
+  /* Used for recovery serialization */
+  public boolean isForAllWorkUnits() {
+    return this.forAllWorkUnits;
+  }
+  
+  @Private
+  /* Used for recovery serialization */
+  public List<Integer> getAllNumPhysicalInputs() {
+    return numPhysicalInputs;
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/d3fdd81b/tez-api/src/main/java/org/apache/tez/runtime/api/events/RootInputConfigureVertexTasksEvent.java
----------------------------------------------------------------------
diff --git a/tez-api/src/main/java/org/apache/tez/runtime/api/events/RootInputConfigureVertexTasksEvent.java b/tez-api/src/main/java/org/apache/tez/runtime/api/events/RootInputConfigureVertexTasksEvent.java
index 1eb7f14..d8c3cce 100644
--- a/tez-api/src/main/java/org/apache/tez/runtime/api/events/RootInputConfigureVertexTasksEvent.java
+++ b/tez-api/src/main/java/org/apache/tez/runtime/api/events/RootInputConfigureVertexTasksEvent.java
@@ -22,15 +22,19 @@ import java.util.List;
 
 import org.apache.tez.dag.api.VertexLocationHint.TaskLocationHint;
 import org.apache.tez.runtime.api.Event;
+import org.apache.tez.runtime.api.RootInputSpecUpdate;
 
 public class RootInputConfigureVertexTasksEvent extends Event {
 
   private final int numTasks;
   private final List<TaskLocationHint> taskLocationHints;
-  
-  public RootInputConfigureVertexTasksEvent(int numTasks, List<TaskLocationHint> locationHints) {
+  private final RootInputSpecUpdate rootInputSpecUpdate;
+
+  public RootInputConfigureVertexTasksEvent(int numTasks, List<TaskLocationHint> locationHints,
+      RootInputSpecUpdate rootInputSpecUpdate) {
     this.numTasks = numTasks;
     this.taskLocationHints = locationHints;
+    this.rootInputSpecUpdate = rootInputSpecUpdate;
   }
 
   public int getNumTasks() {
@@ -41,5 +45,8 @@ public class RootInputConfigureVertexTasksEvent extends Event {
     return taskLocationHints;
   }
 
-  
+  public RootInputSpecUpdate getRootInputSpecUpdate() {
+    return this.rootInputSpecUpdate;
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/d3fdd81b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/Vertex.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/Vertex.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/Vertex.java
index 93f047a..da65458 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/Vertex.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/Vertex.java
@@ -23,6 +23,7 @@ import java.util.Map;
 import java.util.Set;
 
 import javax.annotation.Nullable;
+
 import org.apache.hadoop.yarn.api.records.Resource;
 import org.apache.tez.common.counters.TezCounters;
 import org.apache.tez.dag.api.EdgeManagerDescriptor;
@@ -44,6 +45,7 @@ import org.apache.tez.dag.history.HistoryEvent;
 import org.apache.tez.dag.records.TezTaskID;
 import org.apache.tez.dag.records.TezVertexID;
 import org.apache.tez.runtime.api.OutputCommitter;
+import org.apache.tez.runtime.api.RootInputSpecUpdate;
 import org.apache.tez.runtime.api.impl.GroupInputSpec;
 import org.apache.tez.runtime.api.impl.InputSpec;
 import org.apache.tez.runtime.api.impl.OutputSpec;
@@ -83,7 +85,8 @@ public interface Vertex extends Comparable<Vertex> {
   TaskLocationHint getTaskLocationHint(TezTaskID taskID);
 
   boolean setParallelism(int parallelism, VertexLocationHint vertexLocationHint,
-      Map<String, EdgeManagerDescriptor> sourceEdgeManagers);
+      Map<String, EdgeManagerDescriptor> sourceEdgeManagers,
+      Map<String, RootInputSpecUpdate> rootInputSpecUpdate);
   void setVertexLocationHint(VertexLocationHint vertexLocationHint);
 
   // CHANGE THESE TO LISTS AND MAINTAIN ORDER?

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/d3fdd81b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/RootInputVertexManager.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/RootInputVertexManager.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/RootInputVertexManager.java
index cf68d5e..e1d73e4 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/RootInputVertexManager.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/RootInputVertexManager.java
@@ -18,6 +18,7 @@
 
 package org.apache.tez.dag.app.dag.impl;
 
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 
@@ -27,6 +28,7 @@ import org.apache.tez.dag.api.VertexManagerPlugin;
 import org.apache.tez.dag.api.VertexManagerPluginContext;
 import org.apache.tez.dag.api.VertexManagerPluginContext.TaskWithLocationHint;
 import org.apache.tez.runtime.api.Event;
+import org.apache.tez.runtime.api.RootInputSpecUpdate;
 import org.apache.tez.runtime.api.events.RootInputConfigureVertexTasksEvent;
 import org.apache.tez.runtime.api.events.RootInputDataInformationEvent;
 import org.apache.tez.runtime.api.events.RootInputUpdatePayloadEvent;
@@ -75,8 +77,13 @@ public class RootInputVertexManager implements VertexManagerPlugin {
             .checkState(context.getVertexNumTasks(context.getVertexName()) == -1,
                 "Parallelism for the vertex should be set to -1 if the InputInitializer is setting parallelism");
         RootInputConfigureVertexTasksEvent cEvent = (RootInputConfigureVertexTasksEvent) event;
+        Map<String, RootInputSpecUpdate> rootInputSpecUpdate = new HashMap<String, RootInputSpecUpdate>();
+        rootInputSpecUpdate.put(
+            inputName,
+            cEvent.getRootInputSpecUpdate() == null ? RootInputSpecUpdate
+                .getDefaultSinglePhysicalInputSpecUpdate() : cEvent.getRootInputSpecUpdate());
         context.setVertexParallelism(cEvent.getNumTasks(),
-            new VertexLocationHint(cEvent.getTaskLocationHints()), null);
+            new VertexLocationHint(cEvent.getTaskLocationHints()), null, rootInputSpecUpdate);
       }
       if (event instanceof RootInputUpdatePayloadEvent) {
         // No tasks should have been started yet. Checked by initial state check.

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/d3fdd81b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java
index dcdbe31..bbe0ccb 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexImpl.java
@@ -23,6 +23,7 @@ import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.EnumSet;
+import java.util.HashMap;
 import java.util.Iterator;
 import java.util.LinkedHashMap;
 import java.util.LinkedList;
@@ -36,6 +37,7 @@ import java.util.concurrent.locks.ReadWriteLock;
 import java.util.concurrent.locks.ReentrantReadWriteLock;
 
 import javax.annotation.Nullable;
+
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.classification.InterfaceAudience.Private;
@@ -132,6 +134,7 @@ import org.apache.tez.dag.records.TezTaskID;
 import org.apache.tez.dag.records.TezVertexID;
 import org.apache.tez.runtime.api.OutputCommitter;
 import org.apache.tez.runtime.api.OutputCommitterContext;
+import org.apache.tez.runtime.api.RootInputSpecUpdate;
 import org.apache.tez.runtime.api.events.CompositeDataMovementEvent;
 import org.apache.tez.runtime.api.events.DataMovementEvent;
 import org.apache.tez.runtime.api.events.InputFailedEvent;
@@ -546,10 +549,12 @@ public class VertexImpl implements org.apache.tez.dag.app.dag.Vertex,
   private Map<Vertex, Edge> targetVertices;
   Set<Edge> uninitializedEdges = Sets.newHashSet();
 
-  private Map<String, RootInputLeafOutputDescriptor<InputDescriptor>> additionalInputs;
+  private Map<String, RootInputLeafOutputDescriptor<InputDescriptor>> rootInputDescriptors;
   private Map<String, RootInputLeafOutputDescriptor<OutputDescriptor>> additionalOutputs;
   private Map<String, OutputCommitter> outputCommitters;
-  private final List<InputSpec> additionalInputSpecs = new ArrayList<InputSpec>();
+  private Map<String, RootInputSpecUpdate> rootInputSpecs = new HashMap<String, RootInputSpecUpdate>();
+  private static final RootInputSpecUpdate DEFAULT_ROOT_INPUT_SPECS = RootInputSpecUpdate
+      .getDefaultSinglePhysicalInputSpecUpdate(); 
   private final List<OutputSpec> additionalOutputSpecs = new ArrayList<OutputSpec>();
   private Set<String> inputsWithInitializers;
   private int numInitializedInputs;
@@ -588,6 +593,7 @@ public class VertexImpl implements org.apache.tez.dag.app.dag.Vertex,
   private boolean hasCommitter = false;
   private boolean vertexCompleteSeen = false;
   private Map<String,EdgeManagerDescriptor> recoveredSourceEdgeManagers = null;
+  private Map<String, RootInputSpecUpdate> recoveredRootInputSpecUpdates = null;
 
   // Recovery related flags
   boolean recoveryInitEventSeen = false;
@@ -950,21 +956,22 @@ public class VertexImpl implements org.apache.tez.dag.app.dag.Vertex,
   }
 
   private void handleParallelismUpdate(int newParallelism,
-      Map<String, EdgeManagerDescriptor> sourceEdgeManagers) {
+      Map<String, EdgeManagerDescriptor> sourceEdgeManagers,
+      Map<String, RootInputSpecUpdate> rootInputSpecUpdates) {
     LinkedHashMap<TezTaskID, Task> currentTasks = this.tasks;
     Iterator<Map.Entry<TezTaskID, Task>> iter = currentTasks.entrySet()
         .iterator();
     int i = 0;
     while (iter.hasNext()) {
       i++;
-      Map.Entry<TezTaskID, Task> entry = iter.next();
+      iter.next();
       if (i <= newParallelism) {
         continue;
       }
       iter.remove();
     }
-    this.recoveredSourceEdgeManagers =
-        sourceEdgeManagers;
+    this.recoveredSourceEdgeManagers = sourceEdgeManagers;
+    this.recoveredRootInputSpecUpdates = rootInputSpecUpdates;
   }
 
   @Override
@@ -1003,7 +1010,8 @@ public class VertexImpl implements org.apache.tez.dag.app.dag.Vertex,
           setTaskLocationHints(updatedEvent.getVertexLocationHint());
         }
         numTasks = updatedEvent.getNumTasks();
-        handleParallelismUpdate(numTasks, updatedEvent.getSourceEdgeManagers());
+        handleParallelismUpdate(numTasks, updatedEvent.getSourceEdgeManagers(),
+          updatedEvent.getRootInputSpecUpdates());
         if (LOG.isDebugEnabled()) {
           LOG.debug("Recovered state for vertex after parallelism updated event"
               + ", vertex=" + logIdentifier
@@ -1086,12 +1094,15 @@ public class VertexImpl implements org.apache.tez.dag.app.dag.Vertex,
 
   @Override
   public boolean setParallelism(int parallelism, VertexLocationHint vertexLocationHint,
-      Map<String, EdgeManagerDescriptor> sourceEdgeManagers) {
-    return setParallelism(parallelism, vertexLocationHint, sourceEdgeManagers, false);
+      Map<String, EdgeManagerDescriptor> sourceEdgeManagers,
+      Map<String, RootInputSpecUpdate> rootInputSpecUpdates) {
+    return setParallelism(parallelism, vertexLocationHint, sourceEdgeManagers, rootInputSpecUpdates,
+        false);
   }
 
   private boolean setParallelism(int parallelism, VertexLocationHint vertexLocationHint,
       Map<String, EdgeManagerDescriptor> sourceEdgeManagers,
+      Map<String, RootInputSpecUpdate> rootInputSpecUpdates,
       boolean recovering) {
     if (recovering) {
       writeLock.lock();
@@ -1114,6 +1125,13 @@ public class VertexImpl implements org.apache.tez.dag.app.dag.Vertex,
             }
           }
         }
+        
+        // Restore any rootInputSpecUpdates which may have been registered during a parallelism
+        // update.
+        if (rootInputSpecUpdates != null) {
+          LOG.info("Got updated RootInputsSpecs during recovery: " + rootInputSpecUpdates.toString());
+          this.rootInputSpecs.putAll(rootInputSpecUpdates);
+        }
         return true;
       } finally {
         writeLock.unlock();
@@ -1156,6 +1174,21 @@ public class VertexImpl implements org.apache.tez.dag.app.dag.Vertex,
             }
           }
         }
+        if (rootInputSpecUpdates != null) {
+          LOG.info("Got updated RootInputsSpecs: " + rootInputSpecUpdates.toString());
+          // Sanity check for correct number of updates.
+          for (Entry<String, RootInputSpecUpdate> rootInputSpecUpdateEntry : rootInputSpecUpdates
+              .entrySet()) {
+            Preconditions
+                .checkState(
+                    rootInputSpecUpdateEntry.getValue().isForAllWorkUnits()
+                        || (rootInputSpecUpdateEntry.getValue().getAllNumPhysicalInputs() != null && rootInputSpecUpdateEntry
+                            .getValue().getAllNumPhysicalInputs().size() == parallelism),
+                    "Not enough input spec updates for root input named "
+                        + rootInputSpecUpdateEntry.getKey());
+          }
+          this.rootInputSpecs.putAll(rootInputSpecUpdates);
+        }
         this.numTasks = parallelism;
         this.createTasks();
         LOG.info("Vertex " + getVertexId() + 
@@ -1164,6 +1197,13 @@ public class VertexImpl implements org.apache.tez.dag.app.dag.Vertex,
           getEventHandler().handle(new VertexEventParallelismInitialized(getVertexId()));
         }
       } else {
+        // This is an artificial restriction since there's no way of knowing whether a VertexManager
+        // will attempt to update root input specs. When parallelism has not been initialized, the
+        // Vertex will not be in started state so it's safe to update the specifications.
+        // TODO TEZ-937 - add a mechanism to query vertex managers, or for VMs to indicate readiness
+        // for a vertex to start.
+        Preconditions.checkState(rootInputSpecUpdates == null,
+            "Root Input specs can only be updated when the vertex is configured with -1 tasks");
         if (parallelism >= numTasks) {
           // not that hard to support perhaps. but checking right now since there
           // is no use case for it and checking may catch other bugs.
@@ -1174,15 +1214,14 @@ public class VertexImpl implements org.apache.tez.dag.app.dag.Vertex,
         if (parallelism == numTasks) {
           LOG.info("setParallelism same as current value: " + parallelism + 
               " for vertex: " + logIdentifier);
-          Preconditions
-          .checkArgument(sourceEdgeManagers != null,
-              "Source edge managers must be set when not changing parallelism");
+          Preconditions.checkArgument(sourceEdgeManagers != null,
+              "Source edge managers or RootInputSpecs must be set when not changing parallelism");
         } else {
           LOG.info(
               "Resetting vertex location hints due to change in parallelism for vertex: " + logIdentifier);
           vertexLocationHint = null;
         }
-  
+
         // start buffering incoming events so that we can re-route existing events
         for (Edge edge : sourceVertices.values()) {
           edge.startEventBuffering();
@@ -1237,7 +1276,7 @@ public class VertexImpl implements org.apache.tez.dag.app.dag.Vertex,
         VertexParallelismUpdatedEvent parallelismUpdatedEvent =
             new VertexParallelismUpdatedEvent(vertexId, numTasks,
                 vertexLocationHint,
-                sourceEdgeManagers);
+                sourceEdgeManagers, rootInputSpecUpdates);
         appContext.getHistoryHandler().handle(new DAGHistoryEvent(getDAGId(),
             parallelismUpdatedEvent));
 
@@ -1738,22 +1777,12 @@ public class VertexImpl implements org.apache.tez.dag.app.dag.Vertex,
 
     // Check if any inputs need initializers
     if (event != null) {
-      this.additionalInputs = event.getAdditionalInputs();
-      if (additionalInputs != null) {
-      // FIXME References to descriptor kept in both objects
-        for (InputSpec inputSpec : this.additionalInputSpecs) {
-          if (additionalInputs.containsKey(inputSpec.getSourceVertexName())
-                && additionalInputs.get(inputSpec.getSourceVertexName()).getDescriptor() != null) {
-            inputSpec.setInputDescriptor(
-                additionalInputs.get(inputSpec.getSourceVertexName()).getDescriptor());
-          }
-        }
-      }
+      this.rootInputDescriptors = event.getAdditionalInputs();
     } else {
-      if (additionalInputs != null) {
+      if (rootInputDescriptors != null) {
         LOG.info("Root Inputs exist for Vertex: " + getName() + " : "
-            + additionalInputs);
-        for (RootInputLeafOutputDescriptor<InputDescriptor> input : additionalInputs.values()) {
+            + rootInputDescriptors);
+        for (RootInputLeafOutputDescriptor<InputDescriptor> input : rootInputDescriptors.values()) {
           if (input.getInitializerClassName() != null) {
             if (inputsWithInitializers == null) {
               inputsWithInitializers = Sets.newHashSet();
@@ -2314,7 +2343,7 @@ public class VertexImpl implements org.apache.tez.dag.app.dag.Vertex,
             break;
           }
           if (!vertex.setParallelism(0,
-              null, vertex.recoveredSourceEdgeManagers, true)) {
+              null, vertex.recoveredSourceEdgeManagers, vertex.recoveredRootInputSpecUpdates, true)) {
             LOG.info("Failed to recover edge managers, vertex="
                 + vertex.logIdentifier);
             vertex.finished(VertexState.FAILED,
@@ -2363,7 +2392,8 @@ public class VertexImpl implements org.apache.tez.dag.app.dag.Vertex,
             endState = VertexState.FAILED;
             break;
           }
-          if (!vertex.setParallelism(0, null, vertex.recoveredSourceEdgeManagers, true)) {
+          if (!vertex.setParallelism(0, null, vertex.recoveredSourceEdgeManagers,
+            vertex.recoveredRootInputSpecUpdates, true)) {
             LOG.info("Failed to recover edge managers");
             vertex.finished(VertexState.FAILED,
                 VertexTerminationCause.INIT_FAILURE);
@@ -2537,7 +2567,7 @@ public class VertexImpl implements org.apache.tez.dag.app.dag.Vertex,
           List<RootInputLeafOutputDescriptor<InputDescriptor>> inputList = Lists
               .newArrayListWithCapacity(vertex.inputsWithInitializers.size());
           for (String inputName : vertex.inputsWithInitializers) {
-            inputList.add(vertex.additionalInputs.get(inputName));
+            inputList.add(vertex.rootInputDescriptors.get(inputName));
           }
           LOG.info("Vertex will initialize via inputInitializers "
               + vertex.logIdentifier + ". Starting root input initializers: "
@@ -2581,7 +2611,7 @@ public class VertexImpl implements org.apache.tez.dag.app.dag.Vertex,
           List<RootInputLeafOutputDescriptor<InputDescriptor>> inputList = Lists
               .newArrayListWithCapacity(vertex.inputsWithInitializers.size());
           for (String inputName : vertex.inputsWithInitializers) {
-            inputList.add(vertex.additionalInputs.get(inputName));
+            inputList.add(vertex.rootInputDescriptors.get(inputName));
           }
           LOG.info("Starting root input initializers: "
               + vertex.inputsWithInitializers.size());
@@ -2718,7 +2748,7 @@ public class VertexImpl implements org.apache.tez.dag.app.dag.Vertex,
           " sent by vertex " + splitEvent.getSenderVertex() +
           " numTasks " + splitEvent.getNumTasks());
       vertex.originalOneToOneSplitSource = originalSplitSource;
-      vertex.setParallelism(splitEvent.getNumTasks(), null, null);
+      vertex.setParallelism(splitEvent.getNumTasks(), null, null, null);
       if (vertex.getState() == VertexState.RUNNING || 
           vertex.getState() == VertexState.INITED) {
         return vertex.getState();
@@ -3397,20 +3427,18 @@ public class VertexImpl implements org.apache.tez.dag.app.dag.Vertex,
   public void setAdditionalInputs(List<RootInputLeafOutputProto> inputs) {
     Preconditions.checkArgument(inputs.size() < 2,
         "For now, only a single root input can be specified on a Vertex");
-    this.additionalInputs = Maps.newHashMapWithExpectedSize(inputs.size());
+    this.rootInputDescriptors = Maps.newHashMapWithExpectedSize(inputs.size());
     for (RootInputLeafOutputProto input : inputs) {
 
       InputDescriptor id = DagTypeConverters
           .convertInputDescriptorFromDAGPlan(input.getEntityDescriptor());
 
-      this.additionalInputs.put(input.getName(),
+      this.rootInputDescriptors.put(input.getName(),
           new RootInputLeafOutputDescriptor<InputDescriptor>(input.getName(), id,
               input.hasInitializerClassName() ? input.getInitializerClassName()
                   : null));
-      InputSpec inputSpec = new InputSpec(input.getName(), id, 0);
-      additionalInputSpecs.add(inputSpec);
+      this.rootInputSpecs.put(input.getName(), DEFAULT_ROOT_INPUT_SPECS);
     }
-
   }
 
   @Nullable
@@ -3451,7 +3479,7 @@ public class VertexImpl implements org.apache.tez.dag.app.dag.Vertex,
   @Nullable
   @Override
   public Map<String, RootInputLeafOutputDescriptor<InputDescriptor>> getAdditionalInputs() {
-    return this.additionalInputs;
+    return this.rootInputDescriptors;
   }
 
   @Nullable
@@ -3547,9 +3575,16 @@ public class VertexImpl implements org.apache.tez.dag.app.dag.Vertex,
   // TODO Eventually remove synchronization.
   @Override
   public synchronized List<InputSpec> getInputSpecList(int taskIndex) {
-    inputSpecList = new ArrayList<InputSpec>(
-        this.getInputVerticesCount() + additionalInputSpecs.size());
-    inputSpecList.addAll(additionalInputSpecs);
+    inputSpecList = new ArrayList<InputSpec>(this.getInputVerticesCount()
+        + (rootInputDescriptors == null ? 0 : rootInputDescriptors.size()));
+    if (rootInputDescriptors != null) {
+      for (Entry<String, RootInputLeafOutputDescriptor<InputDescriptor>> rootInputDescriptorEntry : rootInputDescriptors
+          .entrySet()) {
+        inputSpecList.add(new InputSpec(rootInputDescriptorEntry.getKey(),
+            rootInputDescriptorEntry.getValue().getDescriptor(), rootInputSpecs.get(
+                rootInputDescriptorEntry.getKey()).getNumPhysicalInputsForWorkUnit(taskIndex)));
+      }
+    }
     for (Entry<Vertex, Edge> entry : this.getInputVertices().entrySet()) {
       InputSpec inputSpec = entry.getValue().getDestinationSpec(taskIndex);
       if (LOG.isDebugEnabled()) {

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/d3fdd81b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexManager.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexManager.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexManager.java
index 7f8cc14..35c3943 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexManager.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/VertexManager.java
@@ -19,6 +19,7 @@
 package org.apache.tez.dag.app.dag.impl;
 
 import static com.google.common.base.Preconditions.checkNotNull;
+
 import java.io.IOException;
 import java.util.Collection;
 import java.util.List;
@@ -26,6 +27,7 @@ import java.util.Map;
 import java.util.Set;
 
 import javax.annotation.Nullable;
+
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.yarn.api.records.Container;
@@ -50,6 +52,7 @@ import org.apache.tez.dag.app.dag.event.VertexEventRouteEvent;
 import org.apache.tez.dag.records.TezTaskAttemptID;
 import org.apache.tez.dag.records.TezTaskID;
 import org.apache.tez.runtime.api.Event;
+import org.apache.tez.runtime.api.RootInputSpecUpdate;
 import org.apache.tez.runtime.api.events.RootInputDataInformationEvent;
 import org.apache.tez.runtime.api.events.VertexManagerEvent;
 import org.apache.tez.runtime.api.impl.EventMetaData;
@@ -103,8 +106,10 @@ public class VertexManager {
 
     @Override
     public boolean setVertexParallelism(int parallelism, VertexLocationHint vertexLocationHint,
-        Map<String, EdgeManagerDescriptor> sourceEdgeManagers) {
-      return managedVertex.setParallelism(parallelism, vertexLocationHint, sourceEdgeManagers);
+        Map<String, EdgeManagerDescriptor> sourceEdgeManagers,
+        Map<String, RootInputSpecUpdate> rootInputSpecUpdate) { // keyed by root input name; plugins may pass null
+      return managedVertex.setParallelism(parallelism, vertexLocationHint, sourceEdgeManagers,
+          rootInputSpecUpdate); // forwarded unchanged to the managed vertex
     }
 
     @Override

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/d3fdd81b/tez-dag/src/main/java/org/apache/tez/dag/history/events/VertexParallelismUpdatedEvent.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/history/events/VertexParallelismUpdatedEvent.java b/tez-dag/src/main/java/org/apache/tez/dag/history/events/VertexParallelismUpdatedEvent.java
index 8860567..15da86d 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/history/events/VertexParallelismUpdatedEvent.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/history/events/VertexParallelismUpdatedEvent.java
@@ -32,7 +32,11 @@ import org.apache.tez.dag.history.HistoryEvent;
 import org.apache.tez.dag.history.HistoryEventType;
 import org.apache.tez.dag.records.TezVertexID;
 import org.apache.tez.dag.recovery.records.RecoveryProtos.EdgeManagerDescriptorProto;
+import org.apache.tez.dag.recovery.records.RecoveryProtos.RootInputSpecUpdateProto;
 import org.apache.tez.dag.recovery.records.RecoveryProtos.VertexParallelismUpdatedProto;
+import org.apache.tez.runtime.api.RootInputSpecUpdate;
+
+import com.google.common.collect.Maps;
 
 public class VertexParallelismUpdatedEvent implements HistoryEvent {
 
@@ -40,17 +44,20 @@ public class VertexParallelismUpdatedEvent implements HistoryEvent {
   private int numTasks;
   private VertexLocationHint vertexLocationHint;
   private Map<String, EdgeManagerDescriptor> sourceEdgeManagers;
+  private Map<String, RootInputSpecUpdate> rootInputSpecUpdates; // keyed by root input name; may be null
 
   public VertexParallelismUpdatedEvent() {
   }
 
   public VertexParallelismUpdatedEvent(TezVertexID vertexID,
       int numTasks, VertexLocationHint vertexLocationHint,
-      Map<String, EdgeManagerDescriptor> sourceEdgeManagers) {
+      Map<String, EdgeManagerDescriptor> sourceEdgeManagers,
+      Map<String, RootInputSpecUpdate> rootInputSpecUpdates) { // may be null when no root input spec changes
     this.vertexID = vertexID;
     this.numTasks = numTasks;
     this.vertexLocationHint = vertexLocationHint;
     this.sourceEdgeManagers = sourceEdgeManagers;
+    this.rootInputSpecUpdates = rootInputSpecUpdates;
   }
 
   @Override
@@ -88,6 +95,17 @@ public class VertexParallelismUpdatedEvent implements HistoryEvent {
         builder.addEdgeManagerDescriptors(edgeMgrBuilder.build());
       }
     }
+    if (rootInputSpecUpdates != null) { // one RootInputSpecUpdateProto per root input
+      for (Entry<String, RootInputSpecUpdate> entry : rootInputSpecUpdates.entrySet()) {
+        RootInputSpecUpdateProto.Builder rootInputSpecUpdateBuilder = RootInputSpecUpdateProto
+            .newBuilder();
+        rootInputSpecUpdateBuilder.setInputName(entry.getKey());
+        rootInputSpecUpdateBuilder.setForAllWorkUnits(entry.getValue().isForAllWorkUnits());
+        rootInputSpecUpdateBuilder.addAllNumPhysicalInputs(entry.getValue()
+            .getAllNumPhysicalInputs());
+        builder.addRootInputSpecUpdates(rootInputSpecUpdateBuilder.build());
+      }
+    }
     return builder.build();
   }
 
@@ -110,6 +128,20 @@ public class VertexParallelismUpdatedEvent implements HistoryEvent {
             edgeManagerDescriptor);
       }
     }
+    if (proto.getRootInputSpecUpdatesCount() > 0) { // map stays null when no updates were serialized
+      this.rootInputSpecUpdates = Maps.newHashMap();
+      for (RootInputSpecUpdateProto rootInputSpecUpdateProto : proto.getRootInputSpecUpdatesList()) {
+        RootInputSpecUpdate specUpdate;
+        if (rootInputSpecUpdateProto.getForAllWorkUnits()) {
+          specUpdate = RootInputSpecUpdate
+              .createAllTaskRootInputSpecUpdate(rootInputSpecUpdateProto.getNumPhysicalInputs(0)); // NOTE(review): assumes toProto wrote at least one value for all-work-unit updates
+        } else {
+          specUpdate = RootInputSpecUpdate
+              .createPerTaskRootInputSpecUpdate(rootInputSpecUpdateProto.getNumPhysicalInputsList());
+        }
+        this.rootInputSpecUpdates.put(rootInputSpecUpdateProto.getInputName(), specUpdate);
+      }
+    }
   }
 
   @Override
@@ -133,7 +165,9 @@ public class VertexParallelismUpdatedEvent implements HistoryEvent {
         + ", vertexLocationHint=" +
         (vertexLocationHint == null? "null" : vertexLocationHint)
         + ", edgeManagersCount=" +
-        (sourceEdgeManagers == null? "null" : sourceEdgeManagers.size());
+        (sourceEdgeManagers == null? "null" : sourceEdgeManagers.size()) // close the ternary here so the next field is appended even when sourceEdgeManagers is null
+        + ", rootInputSpecUpdateCount="
+        + (rootInputSpecUpdates == null ? "null" : rootInputSpecUpdates.size());
   }
 
   public TezVertexID getVertexID() {
@@ -151,4 +185,8 @@ public class VertexParallelismUpdatedEvent implements HistoryEvent {
   public Map<String, EdgeManagerDescriptor> getSourceEdgeManagers() {
     return sourceEdgeManagers;
   }
+
+  public Map<String, RootInputSpecUpdate> getRootInputSpecUpdates() { // may return null (constructor arg or nothing deserialized)
+    return rootInputSpecUpdates;
+  }
 }

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/d3fdd81b/tez-dag/src/main/proto/HistoryEvents.proto
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/proto/HistoryEvents.proto b/tez-dag/src/main/proto/HistoryEvents.proto
index 654a2fa..5cbe540 100644
--- a/tez-dag/src/main/proto/HistoryEvents.proto
+++ b/tez-dag/src/main/proto/HistoryEvents.proto
@@ -98,11 +98,18 @@ message EdgeManagerDescriptorProto {
   optional TezEntityDescriptorProto entity_descriptor = 2;
 }
 
+message RootInputSpecUpdateProto { // serialized form of one RootInputSpecUpdate, keyed by input_name
+  optional string input_name = 1;
+  optional bool for_all_work_units = 2; // when true, only the first num_physical_inputs value is read back
+  repeated int32 num_physical_inputs = 3; // otherwise presumably one value per work unit (task)
+}
+
 message VertexParallelismUpdatedProto {
   optional string vertex_id = 1;
   optional int32 num_tasks = 2;
   optional VertexLocationHintProto vertex_location_hint = 3;
   repeated EdgeManagerDescriptorProto edge_manager_descriptors = 4;
+  repeated RootInputSpecUpdateProto root_input_spec_updates = 5; // empty when the event carried no root input spec updates
 }
 
 message VertexCommitStartedProto {

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/d3fdd81b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexImpl.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexImpl.java b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexImpl.java
index 31be599..db214e9 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexImpl.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestVertexImpl.java
@@ -35,10 +35,12 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
 
 import com.google.protobuf.ByteString;
+
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.conf.Configuration;
@@ -66,6 +68,8 @@ import org.apache.tez.dag.api.TezConfiguration;
 import org.apache.tez.dag.api.VertexLocationHint;
 import org.apache.tez.dag.api.EdgeProperty.DataMovementType;
 import org.apache.tez.dag.api.VertexLocationHint.TaskLocationHint;
+import org.apache.tez.dag.api.VertexManagerPlugin;
+import org.apache.tez.dag.api.VertexManagerPluginContext;
 import org.apache.tez.dag.api.VertexManagerPluginContext.TaskWithLocationHint;
 import org.apache.tez.dag.api.client.VertexStatus;
 import org.apache.tez.dag.api.oldrecords.TaskState;
@@ -120,6 +124,7 @@ import org.apache.tez.dag.records.TezVertexID;
 import org.apache.tez.runtime.api.Event;
 import org.apache.tez.runtime.api.OutputCommitter;
 import org.apache.tez.runtime.api.OutputCommitterContext;
+import org.apache.tez.runtime.api.RootInputSpecUpdate;
 import org.apache.tez.runtime.api.events.CompositeDataMovementEvent;
 import org.apache.tez.runtime.api.events.DataMovementEvent;
 import org.apache.tez.runtime.api.events.RootInputConfigureVertexTasksEvent;
@@ -130,6 +135,7 @@ import org.apache.tez.test.VertexManagerPluginForTest;
 import org.apache.tez.runtime.api.impl.EventMetaData;
 import org.apache.tez.runtime.api.impl.EventMetaData.EventProducerConsumerType;
 import org.apache.tez.runtime.api.impl.GroupInputSpec;
+import org.apache.tez.runtime.api.impl.InputSpec;
 import org.apache.tez.runtime.api.impl.TezEvent;
 import org.junit.After;
 import org.junit.Assert;
@@ -459,6 +465,64 @@ public class TestVertexImpl {
                 .addInEdgeId("e1")
               .build()
         )
+        .addVertex(
+            VertexPlan.newBuilder()
+                .setName("vertex3")
+                .setType(PlanVertexType.NORMAL)
+                .addInputs(
+                    RootInputLeafOutputProto.newBuilder()
+                        .setInitializerClassName(initializerClassName)
+                        .setName("input3")
+                        .setEntityDescriptor(
+                            TezEntityDescriptorProto.newBuilder()
+                              .setClassName("InputClazz")
+                              .build()
+                        )
+                        .build()
+                    )
+                .setTaskConfig(
+                    PlanTaskConfiguration.newBuilder()
+                        .setNumTasks(-1)
+                        .setVirtualCores(4)
+                        .setMemoryMb(1024)
+                        .setJavaOpts("")
+                        .setTaskModule("x3.y3")
+                        .build()
+                )
+                .setVertexManagerPlugin(TezEntityDescriptorProto.newBuilder()
+                    .setClassName(RootInputSpecUpdaterVertexManager.class.getName())
+                    .setUserPayload(ByteString.copyFrom(new byte[] {0}))) // payload 0: plugin applies one all-task spec update to "input3"
+              .build()
+        )
+        .addVertex(
+            VertexPlan.newBuilder()
+                .setName("vertex4")
+                .setType(PlanVertexType.NORMAL)
+                .addInputs(
+                    RootInputLeafOutputProto.newBuilder()
+                        .setInitializerClassName(initializerClassName)
+                        .setName("input4")
+                        .setEntityDescriptor(
+                            TezEntityDescriptorProto.newBuilder()
+                              .setClassName("InputClazz")
+                              .build()
+                        )
+                        .build()
+                    )
+                .setTaskConfig(
+                    PlanTaskConfiguration.newBuilder()
+                        .setNumTasks(-1)
+                        .setVirtualCores(4)
+                        .setMemoryMb(1024)
+                        .setJavaOpts("")
+                        .setTaskModule("x3.y3") // NOTE(review): same task module string as vertex3 (copy-paste?)
+                        .build()
+                )
+                .setVertexManagerPlugin(TezEntityDescriptorProto.newBuilder()
+                    .setClassName(RootInputSpecUpdaterVertexManager.class.getName())
+                    .setUserPayload(ByteString.copyFrom(new byte[] {1}))) // payload 1: plugin applies per-task spec updates to "input4"
+              .build()
+        )
         .addEdge(
             EdgePlan.newBuilder()
                 .setEdgeDestination(TezEntityDescriptorProto.newBuilder().setClassName("v1_v2"))
@@ -1594,7 +1658,7 @@ public class TestVertexImpl {
     Map<String, EdgeManagerDescriptor> edgeManagerDescriptors =
         Collections.singletonMap(
        v1.getName(), mockEdgeManagerDescriptor);
-    Assert.assertTrue(v3.setParallelism(1, null, edgeManagerDescriptors));
+    Assert.assertTrue(v3.setParallelism(1, null, edgeManagerDescriptors, null)); // null: no root input spec updates
     Assert.assertTrue(v3.sourceVertices.get(v1).getEdgeManager() instanceof
         EdgeManagerForTest);
     Assert.assertEquals(1, v3.getTotalTasks());
@@ -1662,7 +1726,7 @@ public class TestVertexImpl {
     Map<String, EdgeManagerDescriptor> edgeManagerDescriptors =
         Collections.singletonMap(v3.getName(), edgeManagerDescriptor);
     Assert.assertTrue(v5.setParallelism(v5.getTotalTasks() - 1, null,
-        edgeManagerDescriptors)); // Must decrease.
+        edgeManagerDescriptors, null)); // Must decrease. No root input spec updates.
 
     VertexImpl v5Impl = (VertexImpl) v5;
 
@@ -2298,8 +2362,8 @@ public class TestVertexImpl {
     Assert.assertEquals(-1, v1.getTotalTasks());
     Assert.assertEquals(VertexState.INITIALIZING, v1.getState());
     // set the parallelism
-    v1.setParallelism(numTasks, null, null);
-    v2.setParallelism(numTasks, null, null);
+    v1.setParallelism(numTasks, null, null, null); // null: no root input spec updates
+    v2.setParallelism(numTasks, null, null, null); // null: no root input spec updates
     dispatcher.await();
     // parallelism set and vertex starts with pending start event
     Assert.assertEquals(numTasks, v1.getTotalTasks());
@@ -2314,7 +2378,7 @@ public class TestVertexImpl {
     // v3 still initializing with source vertex started. So should start running
     // once num tasks is defined
     Assert.assertEquals(VertexState.INITIALIZING, v3.getState());
-    v3.setParallelism(numTasks, null, null);
+    v3.setParallelism(numTasks, null, null, null); // null: no root input spec updates
     dispatcher.await();
     Assert.assertEquals(numTasks, v3.getTotalTasks());
     Assert.assertEquals(VertexState.RUNNING, v3.getState());
@@ -2423,7 +2487,7 @@ public class TestVertexImpl {
     Assert.assertEquals(VertexState.RUNNING, vertices.get("vertex4").getState());
     // change parallelism
     int newNumTasks = 3;
-    v1.setParallelism(newNumTasks, null, null);
+    v1.setParallelism(newNumTasks, null, null, null); // null: no root input spec updates
     dispatcher.await();
     Assert.assertEquals(newNumTasks, vertices.get("vertex2").getTotalTasks());
     Assert.assertEquals(newNumTasks, vertices.get("vertex3").getTotalTasks());
@@ -2453,7 +2517,7 @@ public class TestVertexImpl {
     Assert.assertEquals(numTasks, vertices.get("vertex4").getTotalTasks());
     // change parallelism
     int newNumTasks = 3;
-    v1.setParallelism(newNumTasks, null, null);
+    v1.setParallelism(newNumTasks, null, null, null); // null: no root input spec updates
     dispatcher.await();
     Assert.assertEquals(newNumTasks, vertices.get("vertex2").getTotalTasks());
     Assert.assertEquals(newNumTasks, vertices.get("vertex3").getTotalTasks());
@@ -2550,6 +2614,11 @@ public class TestVertexImpl {
       Assert.assertEquals(v1Hints.get(i), v1.getTaskLocationHints()[i]);
     }
     Assert.assertEquals(true, runner1.hasShutDown);
+    for (int i = 0; i < 5; i++) { // v1's single root input should report 1 physical input per task
+      List<InputSpec> inputSpecs = v1.getInputSpecList(i);
+      Assert.assertEquals(1, inputSpecs.size());
+      Assert.assertEquals(1, inputSpecs.get(0).getPhysicalEdgeCount());
+    }
     
     VertexImplWithCustomInitializer v2 = (VertexImplWithCustomInitializer) vertices.get("vertex2");
     Assert.assertEquals(VertexState.INITIALIZING, v2.getState());
@@ -2584,6 +2653,73 @@ public class TestVertexImpl {
       Assert.assertEquals(v2Hints.get(i), v2.getTaskLocationHints()[i]);
     }
     Assert.assertEquals(true, runner2.hasShutDown);
+    for (int i = 0; i < 10; i++) {
+      List<InputSpec> inputSpecs = v1.getInputSpecList(i); // NOTE(review): queries v1, not v2, so v2's specs are never checked; v2 also has in-edge e1, so a v2 check would need to expect the extra edge spec
+      Assert.assertEquals(1, inputSpecs.size());
+      Assert.assertEquals(1, inputSpecs.get(0).getPhysicalEdgeCount());
+    }
+  }
+  
+  @SuppressWarnings("unchecked")
+  @Test(timeout = 5000)
+  public void testVertexRootInputSpecUpdateAll() { // vertex3: plugin applies one update (4 physical inputs) to every task
+    useCustomInitializer = true;
+    setupPreDagCreation();
+    dagPlan = createDAGPlanWithInputInitializer("TestInputInitializer");
+    setupPostDagCreation();
+
+    int expectedNumTasks = RootInputSpecUpdaterVertexManager.NUM_TASKS;
+    VertexImplWithCustomInitializer v3 = (VertexImplWithCustomInitializer) vertices
+        .get("vertex3");
+    dispatcher.getEventHandler().handle(
+        new VertexEvent(v3.getVertexId(), VertexEventType.V_INIT));
+    dispatcher.await();
+    Assert.assertEquals(VertexState.INITIALIZING, v3.getState());
+    RootInputInitializerRunnerControlled runner1 = v3.getRootInputInitializerRunner();
+    runner1.completeInputInitialization(); // no-event variant: the vertex manager plugin sets parallelism and specs
+
+    Assert.assertEquals(VertexState.INITED, v3.getState());
+    Assert.assertEquals(expectedNumTasks, v3.getTotalTasks());
+    Assert.assertEquals(RootInputSpecUpdaterVertexManager.class.getName(), v3.getVertexManager()
+        .getPlugin().getClass().getName());
+    Assert.assertEquals(true, runner1.hasShutDown);
+    
+    for (int i = 0; i < expectedNumTasks; i++) {
+      List<InputSpec> inputSpecs = v3.getInputSpecList(i);
+      Assert.assertEquals(1, inputSpecs.size());
+      Assert.assertEquals(4, inputSpecs.get(0).getPhysicalEdgeCount());
+    }
+  }
+  
+  @SuppressWarnings("unchecked")
+  @Test(timeout = 5000)
+  public void testVertexRootInputSpecUpdatePerTask() { // vertex4: plugin applies per-task updates, task i gets i + 1 physical inputs
+    useCustomInitializer = true;
+    setupPreDagCreation();
+    dagPlan = createDAGPlanWithInputInitializer("TestInputInitializer");
+    setupPostDagCreation();
+
+    int expectedNumTasks = RootInputSpecUpdaterVertexManager.NUM_TASKS;
+    VertexImplWithCustomInitializer v4 = (VertexImplWithCustomInitializer) vertices
+        .get("vertex4");
+    dispatcher.getEventHandler().handle(
+        new VertexEvent(v4.getVertexId(), VertexEventType.V_INIT));
+    dispatcher.await();
+    Assert.assertEquals(VertexState.INITIALIZING, v4.getState());
+    RootInputInitializerRunnerControlled runner1 = v4.getRootInputInitializerRunner();
+    runner1.completeInputInitialization(); // no-event variant: the vertex manager plugin sets parallelism and specs
+
+    Assert.assertEquals(VertexState.INITED, v4.getState());
+    Assert.assertEquals(expectedNumTasks, v4.getTotalTasks());
+    Assert.assertEquals(RootInputSpecUpdaterVertexManager.class.getName(), v4.getVertexManager()
+        .getPlugin().getClass().getName());
+    Assert.assertEquals(true, runner1.hasShutDown);
+    
+    for (int i = 0; i < expectedNumTasks; i++) {
+      List<InputSpec> inputSpecs = v4.getInputSpecList(i);
+      Assert.assertEquals(1, inputSpecs.size());
+      Assert.assertEquals(i + 1, inputSpecs.get(0).getPhysicalEdgeCount());
+    }
   }
   
   private List<TaskLocationHint> createTaskLocationHints(int numTasks) {
@@ -2700,11 +2836,17 @@ public class TestVertexImpl {
       dispatcher.await();
     }
 
+    public void completeInputInitialization() { // marks the first input initialized with no events
+      eventHandler.handle(new VertexEventRootInputInitialized(vertexID, inputs.get(0)
+          .getEntityName(), null));
+      dispatcher.await();
+    }
+    
     public void completeInputInitialization(int targetTasks, List<TaskLocationHint> locationHints) {
       List<Event> events = Lists.newArrayListWithCapacity(targetTasks + 1);
 
       RootInputConfigureVertexTasksEvent configEvent = new RootInputConfigureVertexTasksEvent(
-          targetTasks, locationHints);
+          targetTasks, locationHints, null); // null: no root input spec update
       events.add(configEvent);
       for (int i = 0; i < targetTasks; i++) {
         RootInputDataInformationEvent diEvent = new RootInputDataInformationEvent(
@@ -2840,4 +2982,43 @@ public class TestVertexImpl {
     Assert.assertEquals(VertexState.RUNNING, vB.getState());
     Assert.assertEquals(VertexState.RUNNING, vC.getState());
   }
+  
+  public static class RootInputSpecUpdaterVertexManager implements VertexManagerPlugin { // test plugin: sets parallelism plus root input spec updates chosen by its payload
+
+    private VertexManagerPluginContext context;
+    private static final int NUM_TASKS = 5;
+
+    @Override
+    public void initialize(VertexManagerPluginContext context) {
+      this.context = context;
+    }
+
+    @Override
+    public void onVertexStarted(Map<String, List<Integer>> completions) {
+    }
+
+    @Override
+    public void onSourceTaskCompleted(String srcVertexName, Integer taskId) {
+    }
+
+    @Override
+    public void onVertexManagerEventReceived(VertexManagerEvent vmEvent) {
+    }
+
+    @Override
+    public void onRootVertexInitialized(String inputName, InputDescriptor inputDescriptor,
+        List<Event> events) {
+      Map<String, RootInputSpecUpdate> map = new HashMap<String, RootInputSpecUpdate>();
+      if (context.getUserPayload()[0] == 0) { // payload byte 0 -> all-task update; any other value -> per-task updates
+        map.put("input3", RootInputSpecUpdate.createAllTaskRootInputSpecUpdate(4));
+      } else {
+        List<Integer> pInputList = new LinkedList<Integer>();
+        for (int i = 1; i <= NUM_TASKS; i++) {
+          pInputList.add(i);
+        }
+        map.put("input4", RootInputSpecUpdate.createPerTaskRootInputSpecUpdate(pInputList));
+      }
+      context.setVertexParallelism(NUM_TASKS, null, null, map); // no location hints, no edge manager changes
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/d3fdd81b/tez-dag/src/test/java/org/apache/tez/dag/history/events/TestHistoryEventsProtoConversion.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/history/events/TestHistoryEventsProtoConversion.java b/tez-dag/src/test/java/org/apache/tez/dag/history/events/TestHistoryEventsProtoConversion.java
index 164bd2f..9f1ad89 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/history/events/TestHistoryEventsProtoConversion.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/history/events/TestHistoryEventsProtoConversion.java
@@ -44,6 +44,7 @@ import org.apache.tez.dag.records.TezTaskAttemptID;
 import org.apache.tez.dag.records.TezTaskID;
 import org.apache.tez.dag.records.TezVertexID;
 import org.apache.tez.dag.recovery.records.RecoveryProtos.SummaryEventProto;
+import org.apache.tez.runtime.api.RootInputSpecUpdate;
 import org.apache.tez.runtime.api.events.DataMovementEvent;
 import org.apache.tez.runtime.api.impl.EventMetaData;
 import org.apache.tez.runtime.api.impl.EventMetaData.EventProducerConsumerType;
@@ -51,10 +52,13 @@ import org.apache.tez.runtime.api.impl.TezEvent;
 import org.junit.Assert;
 import org.junit.Test;
 
+import com.google.common.collect.Lists;
+
 import java.io.ByteArrayInputStream;
 import java.io.ByteArrayOutputStream;
 import java.io.IOException;
 import java.util.Arrays;
+import java.util.HashMap;
 import java.util.HashSet;
 import java.util.LinkedHashMap;
 import java.util.List;
@@ -261,11 +265,18 @@ public class TestHistoryEventsProtoConversion {
 
   private void testVertexParallelismUpdatedEvent() throws Exception {
     {
+      RootInputSpecUpdate rootInputSpecUpdateBulk = RootInputSpecUpdate // exercises both encodings: all-task ("bulk") and per-task
+          .createAllTaskRootInputSpecUpdate(2);
+      RootInputSpecUpdate rootInputSpecUpdatePerTask = RootInputSpecUpdate
+          .createPerTaskRootInputSpecUpdate(Lists.newArrayList(1, 2, 3));
+      Map<String, RootInputSpecUpdate> rootInputSpecUpdates = new HashMap<String, RootInputSpecUpdate>();
+      rootInputSpecUpdates.put("input1", rootInputSpecUpdateBulk);
+      rootInputSpecUpdates.put("input2", rootInputSpecUpdatePerTask);
       VertexParallelismUpdatedEvent event =
           new VertexParallelismUpdatedEvent(
               TezVertexID.getInstance(
                   TezDAGID.getInstance(ApplicationId.newInstance(0, 1), 1), 111),
-              100, null, null);
+              100, null, null, rootInputSpecUpdates);
       VertexParallelismUpdatedEvent deserializedEvent = (VertexParallelismUpdatedEvent)
           testProtoConversion(event);
       Assert.assertEquals(event.getVertexID(), deserializedEvent.getVertexID());
@@ -274,6 +285,18 @@ public class TestHistoryEventsProtoConversion {
           deserializedEvent.getSourceEdgeManagers());
       Assert.assertEquals(event.getVertexLocationHint(),
           deserializedEvent.getVertexLocationHint());
+      Assert.assertEquals(event.getRootInputSpecUpdates().size(), deserializedEvent
+          .getRootInputSpecUpdates().size());
+      RootInputSpecUpdate deserializedBulk = deserializedEvent.getRootInputSpecUpdates().get("input1"); // NOTE(review): an assertNotNull here would fail more clearly than an NPE below
+      RootInputSpecUpdate deserializedPerTask = deserializedEvent.getRootInputSpecUpdates().get("input2");
+      Assert.assertEquals(rootInputSpecUpdateBulk.isForAllWorkUnits(),
+          deserializedBulk.isForAllWorkUnits());
+      Assert.assertEquals(rootInputSpecUpdateBulk.getAllNumPhysicalInputs(),
+          deserializedBulk.getAllNumPhysicalInputs());
+      Assert.assertEquals(rootInputSpecUpdatePerTask.isForAllWorkUnits(),
+          deserializedPerTask.isForAllWorkUnits());
+      Assert.assertEquals(rootInputSpecUpdatePerTask.getAllNumPhysicalInputs(),
+          deserializedPerTask.getAllNumPhysicalInputs());
       logEvents(event, deserializedEvent);
     }
     {
@@ -289,7 +312,7 @@ public class TestHistoryEventsProtoConversion {
               100, new VertexLocationHint(Arrays.asList(new TaskLocationHint(
                   new HashSet<String>(Arrays.asList("h1")),
               new HashSet<String>(Arrays.asList("r1"))))),
-              sourceEdgeManagers);
+              sourceEdgeManagers, null); // null: no root input spec updates
 
       VertexParallelismUpdatedEvent deserializedEvent = (VertexParallelismUpdatedEvent)
           testProtoConversion(event);

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/d3fdd81b/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/common/MRInputAMSplitGenerator.java
----------------------------------------------------------------------
diff --git a/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/common/MRInputAMSplitGenerator.java b/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/common/MRInputAMSplitGenerator.java
index 1769dbf..52b6b1c 100644
--- a/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/common/MRInputAMSplitGenerator.java
+++ b/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/common/MRInputAMSplitGenerator.java
@@ -38,6 +38,7 @@ import org.apache.tez.mapreduce.protos.MRRuntimeProtos.MRInputUserPayloadProto;
 import org.apache.tez.mapreduce.protos.MRRuntimeProtos.MRSplitProto;
 import org.apache.tez.mapreduce.protos.MRRuntimeProtos.MRSplitsProto;
 import org.apache.tez.runtime.api.Event;
+import org.apache.tez.runtime.api.RootInputSpecUpdate;
 import org.apache.tez.runtime.api.TezRootInputInitializer;
 import org.apache.tez.runtime.api.TezRootInputInitializerContext;
 import org.apache.tez.runtime.api.events.RootInputConfigureVertexTasksEvent;
@@ -169,7 +170,8 @@ public class MRInputAMSplitGenerator implements TezRootInputInitializer {
         .getNumTasks() + 1);
     
     RootInputConfigureVertexTasksEvent configureVertexEvent = new RootInputConfigureVertexTasksEvent(
-        inputSplitInfo.getNumTasks(), inputSplitInfo.getTaskLocationHints());
+        inputSplitInfo.getNumTasks(), inputSplitInfo.getTaskLocationHints(),
+        RootInputSpecUpdate.getDefaultSinglePhysicalInputSpecUpdate()); // presumably one physical input per task -- confirm against RootInputSpecUpdate
     events.add(configureVertexEvent);
 
     if (sendSerializedEvents) {

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/d3fdd81b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManager.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManager.java b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManager.java
index d17b367..5b489ed 100644
--- a/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManager.java
+++ b/tez-runtime-library/src/main/java/org/apache/tez/dag/library/vertexmanager/ShuffleVertexManager.java
@@ -427,7 +427,7 @@ public class ShuffleVertexManager implements VertexManagerPlugin {
         edgeManagers.put(vertex, edgeManagerDescriptor);
       }
       
-      context.setVertexParallelism(finalTaskParallelism, null, edgeManagers);
+      context.setVertexParallelism(finalTaskParallelism, null, edgeManagers, null); // null: this plugin makes no root input spec updates
       updatePendingTasks();      
     }
   }

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/d3fdd81b/tez-runtime-library/src/test/java/org/apache/tez/dag/library/vertexmanager/TestShuffleVertexManager.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/test/java/org/apache/tez/dag/library/vertexmanager/TestShuffleVertexManager.java b/tez-runtime-library/src/test/java/org/apache/tez/dag/library/vertexmanager/TestShuffleVertexManager.java
index 99f4245..fce6bc3 100644
--- a/tez-runtime-library/src/test/java/org/apache/tez/dag/library/vertexmanager/TestShuffleVertexManager.java
+++ b/tez-runtime-library/src/test/java/org/apache/tez/dag/library/vertexmanager/TestShuffleVertexManager.java
@@ -147,7 +147,7 @@ public class TestShuffleVertexManager {
             newEdgeManagers.put(entry.getKey(), edgeManager);
           }
           return null;
-      }}).when(mockContext).setVertexParallelism(eq(2), any(VertexLocationHint.class), anyMap());
+      }}).when(mockContext).setVertexParallelism(eq(2), any(VertexLocationHint.class), anyMap(), anyMap()); // NOTE(review): ShuffleVertexManager passes null as the 4th arg; anyMap() matches null only under Mockito 1.x semantics
     
     // source vertices have 0 tasks. immediate start of all managed tasks
     when(mockContext.getVertexNumTasks(mockSrcVertexId1)).thenReturn(0);
@@ -173,7 +173,7 @@ public class TestShuffleVertexManager {
     manager.onVertexManagerEventReceived(vmEvent);
     manager.onSourceTaskCompleted(mockSrcVertexId1, new Integer(0));
     // managedVertex tasks reduced
-    verify(mockContext, times(0)).setVertexParallelism(anyInt(), any(VertexLocationHint.class), anyMap());
+    verify(mockContext, times(0)).setVertexParallelism(anyInt(), any(VertexLocationHint.class), anyMap(), anyMap());
     Assert.assertEquals(0, manager.pendingTasks.size()); // all tasks scheduled
     Assert.assertEquals(4, scheduledTasks.size());
     Assert.assertEquals(1, manager.numSourceTasksCompleted);
@@ -212,7 +212,7 @@ public class TestShuffleVertexManager {
     manager.onVertexManagerEventReceived(vmEvent);
     manager.onSourceTaskCompleted(mockSrcVertexId1, new Integer(1));
     // managedVertex tasks reduced
-    verify(mockContext).setVertexParallelism(eq(2), any(VertexLocationHint.class), anyMap());
+    verify(mockContext).setVertexParallelism(eq(2), any(VertexLocationHint.class), anyMap(), anyMap());
     Assert.assertEquals(2, newEdgeManagers.size());
     // TODO improve tests for parallelism
     Assert.assertEquals(0, manager.pendingTasks.size()); // all tasks scheduled
@@ -225,7 +225,7 @@ public class TestShuffleVertexManager {
     
     // more completions dont cause recalculation of parallelism
     manager.onSourceTaskCompleted(mockSrcVertexId2, new Integer(0));
-    verify(mockContext).setVertexParallelism(eq(2), any(VertexLocationHint.class), anyMap());
+    verify(mockContext).setVertexParallelism(eq(2), any(VertexLocationHint.class), anyMap(), anyMap());
     Assert.assertEquals(2, newEdgeManagers.size());
     
     EdgeManager edgeManager = newEdgeManagers.values().iterator().next();


Mime
View raw message