tez-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jl...@apache.org
Subject [3/3] tez git commit: TEZ-3770. DAG-aware YARN task scheduler (jlowe)
Date Thu, 25 Jan 2018 15:44:10 GMT
TEZ-3770. DAG-aware YARN task scheduler (jlowe)


Project: http://git-wip-us.apache.org/repos/asf/tez/repo
Commit: http://git-wip-us.apache.org/repos/asf/tez/commit/a9b8bb5a
Tree: http://git-wip-us.apache.org/repos/asf/tez/tree/a9b8bb5a
Diff: http://git-wip-us.apache.org/repos/asf/tez/diff/a9b8bb5a

Branch: refs/heads/master
Commit: a9b8bb5a6840a89844cd75ec86aab2b7457c8028
Parents: 3c7640d
Author: Jason Lowe <jlowe@apache.org>
Authored: Thu Jan 25 09:35:33 2018 -0600
Committer: Jason Lowe <jlowe@apache.org>
Committed: Thu Jan 25 09:35:33 2018 -0600

----------------------------------------------------------------------
 .../apache/tez/dag/api/TezConfiguration.java    |    9 +
 .../apache/tez/serviceplugins/api/DagInfo.java  |    6 +
 .../api/TaskSchedulerContext.java               |    1 +
 .../tez/common/MockDNSToSwitchMapping.java      |   14 +-
 tez-dag/findbugs-exclude.xml                    |    5 +
 .../java/org/apache/tez/dag/app/dag/DAG.java    |    1 -
 .../apache/tez/dag/app/dag/impl/DAGImpl.java    |   39 +
 .../dag/app/rm/DagAwareYarnTaskScheduler.java   | 2064 ++++++++++++++++++
 .../dag/app/rm/TaskSchedulerContextImpl.java    |    5 +
 .../app/rm/TaskSchedulerContextImplWrapper.java |    5 +
 .../tez/dag/app/rm/TaskSchedulerManager.java    |   24 +-
 .../java/org/apache/tez/dag/app/MockClock.java  |   24 +-
 .../app/rm/TestDagAwareYarnTaskScheduler.java   | 1510 +++++++++++++
 .../dag/app/rm/TestTaskSchedulerHelpers.java    |    5 +
 .../tez/dag/helpers/DagInfoImplForTest.java     |   12 +
 .../ControlledScheduledExecutorService.java     |  239 ++
 16 files changed, 3954 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/tez/blob/a9b8bb5a/tez-api/src/main/java/org/apache/tez/dag/api/TezConfiguration.java
----------------------------------------------------------------------
diff --git a/tez-api/src/main/java/org/apache/tez/dag/api/TezConfiguration.java b/tez-api/src/main/java/org/apache/tez/dag/api/TezConfiguration.java
index efe6d6c..6d3050d 100644
--- a/tez-api/src/main/java/org/apache/tez/dag/api/TezConfiguration.java
+++ b/tez-api/src/main/java/org/apache/tez/dag/api/TezConfiguration.java
@@ -680,6 +680,15 @@ public class TezConfiguration extends Configuration {
   public static final String TEZ_AM_DAG_SCHEDULER_CLASS_DEFAULT =
       "org.apache.tez.dag.app.dag.impl.DAGSchedulerNaturalOrder";
 
+  /**
+   * String value. The class to be used for the YARN task scheduler. Expert level setting.
+   */
+  @ConfigurationScope(Scope.AM)
+  @ConfigurationProperty
+  public static final String TEZ_AM_YARN_SCHEDULER_CLASS = TEZ_AM_PREFIX + "yarn.scheduler.class";
+  public static final String TEZ_AM_YARN_SCHEDULER_CLASS_DEFAULT =
+      "org.apache.tez.dag.app.rm.YarnTaskSchedulerService";
+
   /** Int value. The amount of memory in MB to be used by the AppMaster */
   @ConfigurationScope(Scope.AM)
   @ConfigurationProperty(type="integer")

http://git-wip-us.apache.org/repos/asf/tez/blob/a9b8bb5a/tez-api/src/main/java/org/apache/tez/serviceplugins/api/DagInfo.java
----------------------------------------------------------------------
diff --git a/tez-api/src/main/java/org/apache/tez/serviceplugins/api/DagInfo.java b/tez-api/src/main/java/org/apache/tez/serviceplugins/api/DagInfo.java
index 4a8b9fa..328cb62 100644
--- a/tez-api/src/main/java/org/apache/tez/serviceplugins/api/DagInfo.java
+++ b/tez-api/src/main/java/org/apache/tez/serviceplugins/api/DagInfo.java
@@ -16,6 +16,8 @@ package org.apache.tez.serviceplugins.api;
 
 import org.apache.hadoop.security.Credentials;
 
+import java.util.BitSet;
+
 public interface DagInfo {
 
   /**
@@ -35,4 +37,8 @@ public interface DagInfo {
    * @return the credentials for the dag.
    */
   Credentials getCredentials();
+
+  int getTotalVertices();
+
+  BitSet getVertexDescendants(int vertexIndex);
 }

http://git-wip-us.apache.org/repos/asf/tez/blob/a9b8bb5a/tez-api/src/main/java/org/apache/tez/serviceplugins/api/TaskSchedulerContext.java
----------------------------------------------------------------------
diff --git a/tez-api/src/main/java/org/apache/tez/serviceplugins/api/TaskSchedulerContext.java b/tez-api/src/main/java/org/apache/tez/serviceplugins/api/TaskSchedulerContext.java
index 97fe7ae..e188231 100644
--- a/tez-api/src/main/java/org/apache/tez/serviceplugins/api/TaskSchedulerContext.java
+++ b/tez-api/src/main/java/org/apache/tez/serviceplugins/api/TaskSchedulerContext.java
@@ -220,4 +220,5 @@ public interface TaskSchedulerContext extends ServicePluginContextBase {
    */
   AMState getAMState();
 
+  int getVertexIndexForTask(Object task);
 }

http://git-wip-us.apache.org/repos/asf/tez/blob/a9b8bb5a/tez-common/src/test/java/org/apache/tez/common/MockDNSToSwitchMapping.java
----------------------------------------------------------------------
diff --git a/tez-common/src/test/java/org/apache/tez/common/MockDNSToSwitchMapping.java b/tez-common/src/test/java/org/apache/tez/common/MockDNSToSwitchMapping.java
index ec6881f..afcd687 100644
--- a/tez-common/src/test/java/org/apache/tez/common/MockDNSToSwitchMapping.java
+++ b/tez-common/src/test/java/org/apache/tez/common/MockDNSToSwitchMapping.java
@@ -19,7 +19,10 @@
 package org.apache.tez.common;
 
 import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
@@ -32,6 +35,8 @@ import org.apache.hadoop.yarn.util.RackResolver;
  * Overrides CachedDNSToSwitchMapping to ensure that it does not try to resolve hostnames
  */
 public class MockDNSToSwitchMapping extends CachedDNSToSwitchMapping implements DNSToSwitchMapping {
+  private static final Map<String, String> rackMap =
+      Collections.synchronizedMap(new HashMap<String, String>());
 
   private final String defaultRack = "/default-rack";
 
@@ -43,7 +48,11 @@ public class MockDNSToSwitchMapping extends CachedDNSToSwitchMapping implements
   public List<String> resolve(List<String> strings) {
     List<String> resolvedHosts = new ArrayList<String>();
     for (String h : strings) {
-      resolvedHosts.add(defaultRack);
+      String rack = rackMap.get(h);
+      if (rack == null) {
+        rack = defaultRack;
+      }
+      resolvedHosts.add(rack);
     }
     return resolvedHosts;
   }
@@ -62,4 +71,7 @@ public class MockDNSToSwitchMapping extends CachedDNSToSwitchMapping implements
     RackResolver.init(rackResolverConf);
   }
 
+  public static void addRackMapping(String host, String rack) {
+    rackMap.put(host, rack);
+  }
 }

http://git-wip-us.apache.org/repos/asf/tez/blob/a9b8bb5a/tez-dag/findbugs-exclude.xml
----------------------------------------------------------------------
diff --git a/tez-dag/findbugs-exclude.xml b/tez-dag/findbugs-exclude.xml
index c3e099e..1150ccb 100644
--- a/tez-dag/findbugs-exclude.xml
+++ b/tez-dag/findbugs-exclude.xml
@@ -93,6 +93,11 @@
   </Match>
 
   <Match>
+    <Class name="org.apache.tez.dag.app.rm.DagAwareYarnTaskScheduler$PreemptOrderComparator"/>
+    <Bug pattern="SE_COMPARATOR_SHOULD_BE_SERIALIZABLE"/>
+  </Match>
+
+  <Match>
     <Class name="org.apache.tez.dag.app.launcher.ContainerLauncherManager" />
     <Method name="handle" params="org.apache.tez.dag.app.rm.ContainerLauncherEvent" returns="void" />
     <Bug pattern="BC_UNCONFIRMED_CAST" />

http://git-wip-us.apache.org/repos/asf/tez/blob/a9b8bb5a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/DAG.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/DAG.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/DAG.java
index 6c8e8f9..10c4257 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/DAG.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/DAG.java
@@ -68,7 +68,6 @@ public interface DAG extends DagInfo {
   Map<TezVertexID,Vertex> getVertices();
   Vertex getVertex(TezVertexID vertexId);
   List<String> getDiagnostics();
-  int getTotalVertices();
   int getSuccessfulVertices();
   float getProgress();
   float getCompletedTaskProgress();

http://git-wip-us.apache.org/repos/asf/tez/blob/a9b8bb5a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/DAGImpl.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/DAGImpl.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/DAGImpl.java
index 6c67e68..0a775a6 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/DAGImpl.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/DAGImpl.java
@@ -22,6 +22,7 @@ import java.io.IOException;
 import java.net.URL;
 import java.security.PrivilegedExceptionAction;
 import java.util.ArrayList;
+import java.util.BitSet;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.EnumSet;
@@ -190,6 +191,7 @@ public class DAGImpl implements org.apache.tez.dag.app.dag.DAG,
   volatile Map<TezVertexID, Vertex> vertices = new HashMap<TezVertexID, Vertex>();
   @VisibleForTesting
   Map<String, Edge> edges = new HashMap<String, Edge>();
+  ArrayList<BitSet> vertexDescendants;
   private TezCounters dagCounters = new TezCounters();
   private Object fullCountersLock = new Object();
   @VisibleForTesting
@@ -1458,6 +1460,16 @@ public class DAGImpl implements org.apache.tez.dag.app.dag.DAG,
   }
 
   @Override
+  public BitSet getVertexDescendants(int vertexIndex) {
+    readLock.lock();
+    try {
+      return vertexDescendants.get(vertexIndex);
+    } finally {
+      readLock.unlock();
+    }
+  }
+
+  @Override
   public int getSuccessfulVertices() {
     readLock.lock();
     try {
@@ -1560,6 +1572,8 @@ public class DAGImpl implements org.apache.tez.dag.app.dag.DAG,
       parseVertexEdges(this, edgePlans, v);
     }
 
+    computeVertexDescendants();
+
     // Initialize the edges, now that the payload and vertices have been set.
     for (Edge e : edges.values()) {
       try {
@@ -1616,6 +1630,31 @@ public class DAGImpl implements org.apache.tez.dag.app.dag.DAG,
     }
   }
 
+  private void computeVertexDescendants() {
+    vertexDescendants = new ArrayList<>(numVertices);
+    for (int i = 0; i < numVertices; ++i) {
+      vertexDescendants.add(new BitSet(numVertices));
+    }
+    BitSet verticesVisited = new BitSet(numVertices);
+    for (Vertex v : vertices.values()) {
+      computeVertexDescendants(verticesVisited, v);
+    }
+  }
+
+  private BitSet computeVertexDescendants(BitSet verticesVisited, Vertex v) {
+    int vertexIndex = v.getVertexId().getId();
+    BitSet descendants = vertexDescendants.get(vertexIndex);
+    if (!verticesVisited.get(vertexIndex)) {
+      for (Vertex child : v.getOutputVertices().keySet()) {
+        descendants.set(child.getVertexId().getId());
+        BitSet childDescendants = computeVertexDescendants(verticesVisited, child);
+        descendants.or(childDescendants);
+      }
+      verticesVisited.set(vertexIndex);
+    }
+    return descendants;
+  }
+
   private static void assignDAGScheduler(DAGImpl dag) throws TezException {
     String dagSchedulerClassName = dag.dagConf.get(TezConfiguration.TEZ_AM_DAG_SCHEDULER_CLASS,
         TezConfiguration.TEZ_AM_DAG_SCHEDULER_CLASS_DEFAULT);

http://git-wip-us.apache.org/repos/asf/tez/blob/a9b8bb5a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/DagAwareYarnTaskScheduler.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/DagAwareYarnTaskScheduler.java b/tez-dag/src/main/java/org/apache/tez/dag/app/rm/DagAwareYarnTaskScheduler.java
new file mode 100644
index 0000000..dab1cad
--- /dev/null
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/rm/DagAwareYarnTaskScheduler.java
@@ -0,0 +1,2064 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.dag.app.rm;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.apache.commons.lang.mutable.MutableInt;
+import org.apache.commons.math3.random.RandomDataGenerator;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.util.StringUtils;
+import org.apache.hadoop.util.Time;
+import org.apache.hadoop.yarn.api.protocolrecords.RegisterApplicationMasterResponse;
+import org.apache.hadoop.yarn.api.records.Container;
+import org.apache.hadoop.yarn.api.records.ContainerId;
+import org.apache.hadoop.yarn.api.records.ContainerStatus;
+import org.apache.hadoop.yarn.api.records.NodeId;
+import org.apache.hadoop.yarn.api.records.NodeReport;
+import org.apache.hadoop.yarn.api.records.Priority;
+import org.apache.hadoop.yarn.api.records.Resource;
+import org.apache.hadoop.yarn.api.records.ResourceRequest;
+import org.apache.hadoop.yarn.client.api.AMRMClient;
+import org.apache.hadoop.yarn.client.api.async.AMRMClientAsync;
+import org.apache.hadoop.yarn.client.api.async.impl.AMRMClientAsyncImpl;
+import org.apache.hadoop.yarn.client.api.impl.AMRMClientImpl;
+import org.apache.hadoop.yarn.proto.YarnServiceProtos.SchedulerResourceTypes;
+import org.apache.hadoop.yarn.util.RackResolver;
+import org.apache.hadoop.yarn.util.resource.Resources;
+import org.apache.tez.common.ContainerSignatureMatcher;
+import org.apache.tez.common.TezUtils;
+import org.apache.tez.dag.api.TezConfiguration;
+import org.apache.tez.serviceplugins.api.DagInfo;
+import org.apache.tez.serviceplugins.api.TaskAttemptEndReason;
+import org.apache.tez.serviceplugins.api.TaskScheduler;
+import org.apache.tez.serviceplugins.api.TaskSchedulerContext;
+import org.apache.tez.serviceplugins.api.TaskSchedulerContext.AMState;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.GuardedBy;
+import java.util.ArrayList;
+import java.util.BitSet;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.EnumSet;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.NavigableMap;
+import java.util.PriorityQueue;
+import java.util.Set;
+import java.util.TreeMap;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
+import java.util.concurrent.RejectedExecutionException;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.ScheduledThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * A YARN task scheduler that is aware of the dependencies between vertices
+ * in the DAG and takes them into account when deciding how to schedule
+ * and preempt tasks.
+ *
+ * This scheduler makes the assumption that vertex IDs start at 0 and are
+ * densely allocated (i.e.: there are no "gaps" in the vertex ID space).
+ */
+public class DagAwareYarnTaskScheduler extends TaskScheduler
+    implements AMRMClientAsync.CallbackHandler {
+  private static final Logger LOG = LoggerFactory.getLogger(DagAwareYarnTaskScheduler.class);
+  private static final Comparator<HeldContainer> PREEMPT_ORDER_COMPARATOR = new PreemptOrderComparator();
+
+  private final RandomDataGenerator random = new RandomDataGenerator();
+  private AMRMClientAsyncWrapper client;
+  private ScheduledExecutorService reuseExecutor;
+  private ResourceCalculator resourceCalculator;
+  private int numHeartbeats = 0;
+  private Resource totalResources = Resource.newInstance(0, 0);
+  @GuardedBy("this")
+  private Resource allocatedResources = Resource.newInstance(0, 0);
+  private final Set<NodeId> blacklistedNodes = Collections.newSetFromMap(new ConcurrentHashMap<NodeId, Boolean>());
+  private final ContainerSignatureMatcher signatureMatcher;
+  @GuardedBy("this")
+  private final RequestTracker requestTracker = new RequestTracker();
+  @GuardedBy("this")
+  private final Map<ContainerId, HeldContainer> heldContainers = new HashMap<>();
+  @GuardedBy("this")
+  private final IdleContainerTracker idleTracker = new IdleContainerTracker();
+  @GuardedBy("this")
+  private final Map<Object, HeldContainer> taskAssignments = new HashMap<>();
+
+  /** A mapping from the vertex ID to the set of containers assigned to tasks for that vertex */
+  @GuardedBy("this")
+  private final Map<Integer, Set<HeldContainer>> vertexAssignments = new HashMap<>();
+
+  /** If vertex N has at least one task assigned to a container then the corresponding bit at index N is set */
+  @GuardedBy("this")
+  private final BitSet assignedVertices = new BitSet();
+
+  /**
+   * Tracks assigned tasks for released containers so the app can be notified properly when the
+   * container completion event finally arrives.
+   */
+  @GuardedBy("this")
+  private final Map<ContainerId, Object> releasedContainers = new HashMap<>();
+
+  @GuardedBy("this")
+  private final Set<HeldContainer> sessionContainers = new HashSet<>();
+
+  /**
+   * Tracks the set of descendant vertices in the DAG for each vertex.  The BitSet for descendants of vertex N
+   * are at array index N.  If a bit is set at index X in the descendants BitSet then vertex X is a descendant
+   * of vertex N in the DAG.
+   */
+  @GuardedBy("this")
+  private ArrayList<BitSet> vertexDescendants = null;
+
+  private volatile boolean stopRequested = false;
+  private volatile boolean shouldUnregister = false;
+  private volatile boolean hasUnregistered = false;
+
+  // cached configuration parameters
+  private boolean shouldReuseContainers;
+  private boolean reuseRackLocal;
+  private boolean reuseNonLocal;
+  private long localitySchedulingDelay;
+  private long idleContainerTimeoutMin;
+  private long idleContainerTimeoutMax;
+  private int sessionNumMinHeldContainers;
+  private int preemptionPercentage;
+  private int numHeartbeatsBetweenPreemptions;
+  private int lastPreemptionHeartbeat = 0;
+  private long preemptionMaxWaitTime;
+
+  public DagAwareYarnTaskScheduler(TaskSchedulerContext taskSchedulerContext) {
+    super(taskSchedulerContext);
+    signatureMatcher = taskSchedulerContext.getContainerSignatureMatcher();
+  }
+
+  @Override
+  public void initialize() throws Exception {
+    initialize(new AMRMClientAsyncWrapper(new AMRMClientImpl<TaskRequest>(), 1000, this));
+  }
+
+  void initialize(AMRMClientAsyncWrapper client) throws Exception {
+    super.initialize();
+    this.client = client;
+    Configuration conf = TezUtils.createConfFromUserPayload(getContext().getInitialUserPayload());
+    client.init(conf);
+
+    int heartbeatIntervalMax = conf.getInt(
+        TezConfiguration.TEZ_AM_RM_HEARTBEAT_INTERVAL_MS_MAX,
+        TezConfiguration.TEZ_AM_RM_HEARTBEAT_INTERVAL_MS_MAX_DEFAULT);
+    client.setHeartbeatInterval(heartbeatIntervalMax);
+
+    shouldReuseContainers = conf.getBoolean(
+        TezConfiguration.TEZ_AM_CONTAINER_REUSE_ENABLED,
+        TezConfiguration.TEZ_AM_CONTAINER_REUSE_ENABLED_DEFAULT);
+    reuseRackLocal = conf.getBoolean(
+        TezConfiguration.TEZ_AM_CONTAINER_REUSE_RACK_FALLBACK_ENABLED,
+        TezConfiguration.TEZ_AM_CONTAINER_REUSE_RACK_FALLBACK_ENABLED_DEFAULT);
+    reuseNonLocal = conf
+      .getBoolean(
+        TezConfiguration.TEZ_AM_CONTAINER_REUSE_NON_LOCAL_FALLBACK_ENABLED,
+        TezConfiguration.TEZ_AM_CONTAINER_REUSE_NON_LOCAL_FALLBACK_ENABLED_DEFAULT);
+    Preconditions.checkArgument(
+      ((!reuseRackLocal && !reuseNonLocal) || (reuseRackLocal)),
+      "Re-use Rack-Local cannot be disabled if Re-use Non-Local has been"
+      + " enabled");
+
+    localitySchedulingDelay = conf.getLong(
+      TezConfiguration.TEZ_AM_CONTAINER_REUSE_LOCALITY_DELAY_ALLOCATION_MILLIS,
+      TezConfiguration.TEZ_AM_CONTAINER_REUSE_LOCALITY_DELAY_ALLOCATION_MILLIS_DEFAULT);
+    Preconditions.checkArgument(localitySchedulingDelay >= 0,
+        "Locality Scheduling delay should be >=0");
+
+    idleContainerTimeoutMin = conf.getLong(
+        TezConfiguration.TEZ_AM_CONTAINER_IDLE_RELEASE_TIMEOUT_MIN_MILLIS,
+        TezConfiguration.TEZ_AM_CONTAINER_IDLE_RELEASE_TIMEOUT_MIN_MILLIS_DEFAULT);
+    Preconditions.checkArgument(idleContainerTimeoutMin >= 0 || idleContainerTimeoutMin == -1,
+      "Idle container release min timeout should be either -1 or >=0");
+
+    idleContainerTimeoutMax = conf.getLong(
+        TezConfiguration.TEZ_AM_CONTAINER_IDLE_RELEASE_TIMEOUT_MAX_MILLIS,
+        TezConfiguration.TEZ_AM_CONTAINER_IDLE_RELEASE_TIMEOUT_MAX_MILLIS_DEFAULT);
+    Preconditions.checkArgument(
+        idleContainerTimeoutMax >= 0 && idleContainerTimeoutMax >= idleContainerTimeoutMin,
+        "Idle container release max timeout should be >=0 and >= " +
+        TezConfiguration.TEZ_AM_CONTAINER_IDLE_RELEASE_TIMEOUT_MIN_MILLIS);
+
+    sessionNumMinHeldContainers = conf.getInt(TezConfiguration.TEZ_AM_SESSION_MIN_HELD_CONTAINERS,
+        TezConfiguration.TEZ_AM_SESSION_MIN_HELD_CONTAINERS_DEFAULT);
+    Preconditions.checkArgument(sessionNumMinHeldContainers >= 0,
+        "Session minimum held containers should be >=0");
+
+    preemptionPercentage = conf.getInt(TezConfiguration.TEZ_AM_PREEMPTION_PERCENTAGE,
+        TezConfiguration.TEZ_AM_PREEMPTION_PERCENTAGE_DEFAULT);
+    Preconditions.checkArgument(preemptionPercentage >= 0 && preemptionPercentage <= 100,
+        "Preemption percentage should be between 0-100");
+
+    numHeartbeatsBetweenPreemptions = conf.getInt(
+        TezConfiguration.TEZ_AM_PREEMPTION_HEARTBEATS_BETWEEN_PREEMPTIONS,
+        TezConfiguration.TEZ_AM_PREEMPTION_HEARTBEATS_BETWEEN_PREEMPTIONS_DEFAULT);
+    Preconditions.checkArgument(numHeartbeatsBetweenPreemptions >= 1,
+        "Heartbeats between preemptions should be >=1");
+
+    preemptionMaxWaitTime = conf.getInt(TezConfiguration.TEZ_AM_PREEMPTION_MAX_WAIT_TIME_MS,
+        TezConfiguration.TEZ_AM_PREEMPTION_MAX_WAIT_TIME_MS_DEFAULT);
+    Preconditions.checkArgument(preemptionMaxWaitTime >=0, "Preemption max wait time must be >=0");
+
+    LOG.info("scheduler initialized with maxRMHeartbeatInterval:" + heartbeatIntervalMax +
+            " reuseEnabled:" + shouldReuseContainers +
+            " reuseRack:" + reuseRackLocal +
+            " reuseAny:" + reuseNonLocal +
+            " localityDelay:" + localitySchedulingDelay +
+            " preemptPercentage:" + preemptionPercentage +
+            " preemptMaxWaitTime:" + preemptionMaxWaitTime +
+            " numHeartbeatsBetweenPreemptions:" + numHeartbeatsBetweenPreemptions +
+            " idleContainerMinTimeout:" + idleContainerTimeoutMin +
+            " idleContainerMaxTimeout:" + idleContainerTimeoutMax +
+            " sessionMinHeldContainers:" + sessionNumMinHeldContainers);
+  }
+
+  @Override
+  public void start() throws Exception {
+    super.start();
+    client.start();
+    if (shouldReuseContainers) {
+      reuseExecutor = createExecutor();
+    }
+    TaskSchedulerContext ctx = getContext();
+    RegisterApplicationMasterResponse response = client.registerApplicationMaster(
+        ctx.getAppHostName(), ctx.getAppClientPort(), ctx.getAppTrackingUrl());
+    ctx.setApplicationRegistrationData(response.getMaximumResourceCapability(),
+        response.getApplicationACLs(), response.getClientToAMTokenMasterKey(),
+        response.getQueue());
+    if (response.getSchedulerResourceTypes().contains(SchedulerResourceTypes.CPU)) {
+      resourceCalculator = new MemCpuResourceCalculator();
+    } else {
+      resourceCalculator = new MemResourceCalculator();
+    }
+  }
+
+  protected ScheduledExecutorService createExecutor() {
+    return new ReuseContainerExecutor();
+  }
+
+  protected long now() {
+    return Time.monotonicNow();
+  }
+
+  @Override
+  public void initiateStop() {
+    super.initiateStop();
+    LOG.debug("Initiating stop of task scheduler");
+    stopRequested = true;
+    List<ContainerId> releasedLaunchedContainers;
+    synchronized (this) {
+      releasedLaunchedContainers = new ArrayList<>(heldContainers.size());
+      List<HeldContainer> heldList = new ArrayList<>(heldContainers.values());
+      for (HeldContainer hc : heldList) {
+        if (releaseContainer(hc)) {
+          releasedLaunchedContainers.add(hc.getId());
+        }
+      }
+
+      List<Object> tasks = requestTracker.getTasks();
+      for (Object task : tasks) {
+        removeTaskRequest(task);
+      }
+    }
+
+    // perform app callback outside of locks
+    for (ContainerId id : releasedLaunchedContainers) {
+      getContext().containerBeingReleased(id);
+    }
+  }
+
+  @Override
+  public void shutdown() throws Exception {
+    super.shutdown();
+    if (reuseExecutor != null) {
+      reuseExecutor.shutdown();
+      reuseExecutor.awaitTermination(2, TimeUnit.SECONDS);
+    }
+    synchronized (this) {
+      if (shouldUnregister && !hasUnregistered) {
+          TaskSchedulerContext.AppFinalStatus status = getContext().getFinalAppStatus();
+          LOG.info("Unregistering from RM, exitStatus={} exitMessage={} trackingURL={}",
+              status.exitStatus, status.exitMessage, status.postCompletionTrackingUrl);
+          client.unregisterApplicationMaster(status.exitStatus,
+              status.exitMessage,
+              status.postCompletionTrackingUrl);
+          hasUnregistered = true;
+      }
+    }
+    client.stop();
+  }
+
+  @Override
+  public void onContainersAllocated(List<Container> containers) {
+    AMState appState = getContext().getAMState();
+    if (stopRequested || appState == AMState.COMPLETED) {
+      LOG.info("Ignoring {} allocations since app is terminating", containers.size());
+      for (Container c : containers) {
+        client.releaseAssignedContainer(c.getId());
+      }
+      return;
+    }
+    List<Assignment> assignments = assignNewContainers(containers, appState, getContext().isSession());
+    informAppAboutAssignments(assignments);
+  }
+
+  private synchronized List<Assignment> assignNewContainers(List<Container> newContainers,
+      AMState appState, boolean isSession) {
+    // try to assign the containers as node-local
+    List<Assignment> assignments = new ArrayList<>(newContainers.size());
+    List<HeldContainer> unassigned = new ArrayList<>(newContainers.size());
+    for (Container c : newContainers) {
+      HeldContainer hc = new HeldContainer(c);
+      heldContainers.put(hc.getId(), hc);
+      Resources.addTo(allocatedResources, c.getResource());
+      tryAssignNewContainer(hc, hc.getHost(), assignments, unassigned);
+    }
+
+    // try to assign the remaining containers as rack-local
+    List<HeldContainer> containers = unassigned;
+    unassigned = new ArrayList<>(containers.size());
+    for (HeldContainer hc : containers) {
+      tryAssignNewContainer(hc, hc.getRack(), assignments, unassigned);
+    }
+
+    // try to assign the remaining containers without locality
+    containers = unassigned;
+    unassigned = new ArrayList<>(containers.size());
+    for (HeldContainer hc : containers) {
+      tryAssignNewContainer(hc, ResourceRequest.ANY, assignments, unassigned);
+    }
+
+    for (HeldContainer hc : unassigned) {
+      if (shouldReuseContainers) {
+        idleTracker.add(hc);
+        TaskRequest assigned = tryAssignReuseContainer(hc, appState, isSession);
+        if (assigned != null) {
+          assignments.add(new Assignment(assigned, hc.getContainer()));
+        }
+      } else {
+        releaseContainer(hc);
+      }
+    }
+
+    return assignments;
+  }
+
+  /**
+   * Try to assign a newly acquired container to a task of the same priority.
+   *
+   * @param hc the container to assign
+   * @param location the locality to consider for assignment
+   * @param assignments list to update if container is assigned
+   * @param unassigned list to update if container is not assigned
+   */
+  @GuardedBy("this")
+  private void tryAssignNewContainer(HeldContainer hc, String location,
+      List<Assignment> assignments, List<HeldContainer> unassigned) {
+    List<? extends Collection<TaskRequest>> results = client.getMatchingRequests(hc.getPriority(),
+        location, hc.getCapability());
+    if (!results.isEmpty()) {
+      for (Collection<TaskRequest> requests : results) {
+        if (!requests.isEmpty()) {
+          TaskRequest request = requests.iterator().next();
+          assignContainer(request, hc, location);
+          assignments.add(new Assignment(request, hc.getContainer()));
+          return;
+        }
+      }
+    }
+
+    unassigned.add(hc);
+  }
+
+  @GuardedBy("this")
+  @Nullable
+  private TaskRequest tryAssignReuseContainer(HeldContainer hc,
+      AMState appState, boolean isSession) {
+    if (stopRequested) {
+      return null;
+    }
+
+    TaskRequest assignedRequest = null;
+    switch (appState) {
+    case IDLE:
+      handleReuseContainerWhenIdle(hc, isSession);
+      break;
+    case RUNNING_APP:
+      if (requestTracker.isEmpty()) {
+        // treat no requests as if app is idle
+        handleReuseContainerWhenIdle(hc, isSession);
+      } else {
+        assignedRequest = tryAssignReuseContainerAppRunning(hc);
+        if (assignedRequest == null) {
+          if (hc.atMaxMatchLevel()) {
+            LOG.info("Releasing idle container {} due to pending requests", hc.getId());
+            releaseContainer(hc);
+          } else {
+            hc.scheduleForReuse(localitySchedulingDelay);
+          }
+        }
+      }
+      break;
+    case COMPLETED:
+      LOG.info("Releasing container {} because app has completed", hc.getId());
+      releaseContainer(hc);
+      break;
+    default:
+      throw new IllegalStateException("Unexpected app state " + appState);
+    }
+
+    return assignedRequest;
+  }
+
+  @GuardedBy("this")
+  private void handleReuseContainerWhenIdle(HeldContainer hc, boolean isSession) {
+    if (isSession && sessionContainers.isEmpty() && sessionNumMinHeldContainers > 0) {
+      computeSessionContainers();
+    }
+
+    if (sessionContainers.contains(hc)) {
+      LOG.info("Retaining container {} since it is a session container", hc.getId());
+      hc.resetMatchingLevel();
+    } else {
+      long now = now();
+      long expiration = hc.getIdleExpirationTimestamp(now);
+      if (now >= expiration) {
+        LOG.info("Releasing expired idle container {}", hc.getId());
+        releaseContainer(hc);
+      } else {
+        hc.scheduleForReuse(expiration - now);
+      }
+    }
+  }
+
+  @GuardedBy("this")
+  @Nullable
+  private TaskRequest tryAssignReuseContainerAppRunning(HeldContainer hc) {
+    if (!hc.isAssignable()) {
+      LOG.debug("Skipping scheduling of container {} because its state is {}", hc.getId(), hc.getState());
+      return null;
+    }
+
+    TaskRequest assignedRequest = tryAssignReuseContainerForAffinity(hc);
+    if (assignedRequest != null) {
+      return assignedRequest;
+    }
+
+    for (Entry<Priority,RequestPriorityStats> entry : requestTracker.getStatsEntries()) {
+      Priority priority = entry.getKey();
+      RequestPriorityStats stats = entry.getValue();
+      if (!stats.allowedVertices.intersects(stats.vertices)) {
+        LOG.debug("Skipping requests at priority {} because all requesting vertices are blocked by higher priority requests",
+            priority);
+        continue;
+      }
+
+      String matchLocation = hc.getMatchingLocation();
+      if (stats.localityCount <= 0) {
+        LOG.debug("Overriding locality match of container {} to ANY since there are no locality requests at priority {}",
+            hc.getId(), priority);
+        matchLocation = ResourceRequest.ANY;
+      }
+      assignedRequest = tryAssignReuseContainerForPriority(hc, matchLocation,
+          priority, stats.allowedVertices);
+      if (assignedRequest != null) {
+        break;
+      }
+    }
+    return assignedRequest;
+  }
+
+  /**
+   * Attempt to assign this container to one of the task requests that asked
+   * for affinity to it, skipping requests whose vertex is currently blocked.
+   *
+   * @param hc the idle container
+   * @return the assigned request, or null if no affinity request was eligible
+   */
+  @GuardedBy("this")
+  @Nullable
+  private TaskRequest tryAssignReuseContainerForAffinity(HeldContainer hc) {
+    Collection<TaskRequest> affinities = hc.getAffinities();
+    if (affinities == null) {
+      return null;
+    }
+    for (TaskRequest request : affinities) {
+      if (!requestTracker.isRequestBlocked(request)) {
+        assignContainer(request, hc, hc.getId());
+        return request;
+      }
+      LOG.debug("Cannot assign task {} to container {} since vertex {} is a descendant of pending tasks",
+          request.getTask(), hc.getId(), request.getVertexIndex());
+    }
+    return null;
+  }
+
+  /**
+   * Try to assign this container to a matching request at the given priority,
+   * honoring the allowed-vertices mask and container signature compatibility.
+   *
+   * @param hc the idle container
+   * @param matchLocation the locality to match (host, rack, or ANY)
+   * @param priority the request priority to search
+   * @param allowedVertices bitmask of vertices allowed to be scheduled
+   * @return the assigned request, or null if nothing matched
+   */
+  @GuardedBy("this")
+  @Nullable
+  private TaskRequest tryAssignReuseContainerForPriority(HeldContainer hc, String matchLocation,
+      Priority priority, BitSet allowedVertices) {
+    List<? extends Collection<TaskRequest>> results = client.getMatchingRequests(priority, matchLocation, hc.getCapability());
+    if (results.isEmpty()) {
+      return null;
+    }
+
+    for (Collection<TaskRequest> requests : results) {
+      for (TaskRequest request : requests) {
+        final int vertexIndex = request.getVertexIndex();
+        // skip vertices blocked by pending higher-priority ancestors
+        if (!allowedVertices.get(vertexIndex)) {
+          LOG.debug("Not assigning task {} since it is a descendant of a pending vertex", request.getTask());
+          continue;
+        }
+
+        // a null signature means the container has not launched; any task fits
+        Object signature = hc.getSignature();
+        if (signature == null || signatureMatcher.isSuperSet(signature, request.getContainerSignature())) {
+          assignContainer(request, hc, matchLocation);
+          return request;
+        }
+      }
+    }
+    return null;
+  }
+
+  /**
+   * Deliver a batch of task assignment notifications to the app.  This should
+   * not be called with any locks held.
+   *
+   * @param assignments the assignments to report
+   */
+  private void informAppAboutAssignments(List<Assignment> assignments) {
+    for (Assignment assignment : assignments) {
+      informAppAboutAssignment(assignment.request, assignment.container);
+    }
+  }
+
+  /**
+   * Inform the app about a task assignment.  This should not be called with
+   * any locks held.
+   *
+   * @param request the corresponding task request
+   * @param container the container assigned to the task
+   */
+  private void informAppAboutAssignment(TaskRequest request, Container container) {
+    if (blacklistedNodes.contains(container.getNodeId())) {
+      Object task = request.getTask();
+      LOG.info("Container {} allocated for task {} on blacklisted node {}",
+          container.getId(), container.getNodeId(), task);
+      deallocateContainer(container.getId());
+      // its ok to submit the same request again because the RM will not give us
+      // the bad/unhealthy nodes again. The nodes may become healthy/unblacklisted
+      // and so its better to give the RM the full information.
+      allocateTask(task, request.getCapability(),
+          (request.getNodes() == null ? null :
+              request.getNodes().toArray(new String[request.getNodes().size()])),
+          (request.getRacks() == null ? null :
+              request.getRacks().toArray(new String[request.getRacks().size()])),
+          request.getPriority(),
+          request.getContainerSignature(),
+          request.getCookie());
+    } else {
+      getContext().taskAllocated(request.getTask(), request.getCookie(), container);
+    }
+  }
+
+  /**
+   * Select which held containers to retain as session containers, up to
+   * sessionNumMinHeldContainers.  Containers are first budgeted evenly across
+   * racks, then chosen evenly across nodes without exceeding each rack's
+   * budget.  Selected containers are added to sessionContainers.
+   */
+  @GuardedBy("this")
+  private void computeSessionContainers() {
+    // per-rack count of launched held containers
+    Map<String, MutableInt> rackHeldNumber = new HashMap<>();
+    // per-node list of launched held containers
+    Map<String, List<HeldContainer>> nodeHeldContainers = new HashMap<>();
+    for(HeldContainer heldContainer : heldContainers.values()) {
+      if (heldContainer.getSignature() == null) {
+        // skip containers that have not been launched as there is no process to reuse
+        continue;
+      }
+      MutableInt count = rackHeldNumber.get(heldContainer.getRack());
+      if (count == null) {
+        count = new MutableInt(0);
+        rackHeldNumber.put(heldContainer.getRack(), count);
+      }
+      count.increment();
+      String host = heldContainer.getHost();
+      List<HeldContainer> nodeContainers = nodeHeldContainers.get(host);
+      if (nodeContainers == null) {
+        nodeContainers = new LinkedList<>();
+        nodeHeldContainers.put(host, nodeContainers);
+      }
+      nodeContainers.add(heldContainer);
+    }
+
+    // budget of how many containers each rack may retain, computed below
+    Map<String, MutableInt> rackToHoldNumber = new HashMap<>();
+    for (String rack : rackHeldNumber.keySet()) {
+      rackToHoldNumber.put(rack, new MutableInt(0));
+    }
+
+    // distribute evenly across nodes
+    // the loop assigns 1 container per rack over all racks
+    int containerCount = 0;
+    while (containerCount < sessionNumMinHeldContainers && !rackHeldNumber.isEmpty()) {
+      Iterator<Entry<String, MutableInt>> iter = rackHeldNumber.entrySet().iterator();
+      while (containerCount < sessionNumMinHeldContainers && iter.hasNext()) {
+        Entry<String, MutableInt> entry = iter.next();
+        MutableInt rackCount = entry.getValue();
+        rackCount.decrement();
+        if (rackCount.intValue() >=0) {
+          containerCount++;
+          rackToHoldNumber.get(entry.getKey()).increment();
+        } else {
+          // rack has no containers left to offer; drop it from consideration
+          iter.remove();
+        }
+      }
+    }
+
+    // distribute containers evenly across nodes while not exceeding rack limit
+    // the loop assigns 1 container per node over all nodes
+    containerCount = 0;
+    while (containerCount < sessionNumMinHeldContainers && !nodeHeldContainers.isEmpty()) {
+      Iterator<Entry<String, List<HeldContainer>>> iter = nodeHeldContainers.entrySet().iterator();
+      while (containerCount < sessionNumMinHeldContainers && iter.hasNext()) {
+        List<HeldContainer> nodeContainers = iter.next().getValue();
+        if (nodeContainers.isEmpty()) {
+          // node is empty. remove it.
+          iter.remove();
+          continue;
+        }
+        HeldContainer heldContainer = nodeContainers.remove(nodeContainers.size() - 1);
+        MutableInt holdCount = rackToHoldNumber.get(heldContainer.getRack());
+        holdCount.decrement();
+        if (holdCount.intValue() >= 0) {
+          // rack can hold a container
+          containerCount++;
+          sessionContainers.add(heldContainer);
+        } else {
+          // rack limit reached. remove node.
+          iter.remove();
+        }
+      }
+    }
+
+    LOG.info("Identified {} session containers out of {} total containers",
+        sessionContainers.size(), heldContainers.size());
+  }
+
+  /**
+   * Schedule all currently retained session containers for reuse matching and
+   * clear the session container set.
+   */
+  @GuardedBy("this")
+  private void activateSessionContainers() {
+    if (sessionContainers.isEmpty()) {
+      return;
+    }
+    for (HeldContainer sessionContainer : sessionContainers) {
+      if (sessionContainer.isAssignable()) {
+        sessionContainer.scheduleForReuse(localitySchedulingDelay);
+      }
+    }
+    sessionContainers.clear();
+  }
+
+  /**
+   * RM callback for completed containers.  Maps each completed container back
+   * to its task (either a container we released or one we were still holding)
+   * and reports the completion to the app outside the scheduler lock.
+   *
+   * @param statuses the completed container statuses from the RM
+   */
+  @Override
+  public void onContainersCompleted(List<ContainerStatus> statuses) {
+    if (stopRequested) {
+      return;
+    }
+
+    List<TaskStatus> taskStatusList = new ArrayList<>(statuses.size());
+    synchronized (this) {
+      for (ContainerStatus status : statuses) {
+        ContainerId cid = status.getContainerId();
+        LOG.info("Container {} completed with status {}", cid, status);
+        // first check containers we already released on behalf of a task
+        Object task = releasedContainers.remove(cid);
+        if (task == null) {
+          // otherwise this may be a container we are still holding
+          HeldContainer hc = heldContainers.get(cid);
+          if (hc != null) {
+            task = containerCompleted(hc);
+          }
+        }
+        if (task != null) {
+          taskStatusList.add(new TaskStatus(task, status));
+        }
+      }
+    }
+
+    // perform app callback outside of locks
+    for (TaskStatus taskStatus : taskStatusList) {
+      getContext().containerCompleted(taskStatus.task, taskStatus.status);
+    }
+  }
+
+  /**
+   * RM callback for node report updates; forwarded to the app unless the
+   * scheduler is stopping.
+   *
+   * @param updatedNodes the updated node reports
+   */
+  @Override
+  public void onNodesUpdated(List<NodeReport> updatedNodes) {
+    if (stopRequested) {
+      return;
+    }
+    getContext().nodesUpdated(updatedNodes);
+  }
+
+  /**
+   * Heartbeat hook invoked by the async RM client.  Besides reporting app
+   * progress, this lazily captures the cluster's total resources on the first
+   * callback, emits a periodic status log, and evaluates whether preemption
+   * is needed (app preemption callbacks are issued outside the lock).
+   *
+   * @return the app's progress, or 1 if the scheduler is stopping
+   */
+  @Override
+  public float getProgress() {
+    if (stopRequested) {
+      return 1;
+    }
+
+    Collection<ContainerId> preemptedContainers;
+    synchronized (this) {
+      Resource freeResources = getAvailableResources();
+      if (totalResources.getMemory() == 0) {
+        // assume this is the first allocate callback. nothing is allocated.
+        // available resource = totalResource
+        // TODO this will not handle dynamic changes in resources
+        totalResources = Resources.clone(freeResources);
+        LOG.info("App total resource memory: {} cpu: {} activeAssignments: {}",
+            totalResources.getMemory(), totalResources.getVirtualCores(), taskAssignments.size());
+      }
+
+      ++numHeartbeats;
+      // log periodically (every 50 heartbeats) or always when debugging
+      if (LOG.isDebugEnabled() || numHeartbeats % 50 == 1) {
+        LOG.info(constructPeriodicLog(freeResources));
+      }
+
+      preemptedContainers = maybePreempt(freeResources);
+      if (preemptedContainers != null && !preemptedContainers.isEmpty()) {
+        lastPreemptionHeartbeat = numHeartbeats;
+      }
+    }
+
+    // perform app callback outside of locks
+    if (preemptedContainers != null && !preemptedContainers.isEmpty()) {
+      for (ContainerId cid : preemptedContainers) {
+        LOG.info("Preempting container {} currently allocated to a task", cid);
+        getContext().preemptContainer(cid);
+      }
+    }
+
+    return getContext().getProgress();
+  }
+
+  /** RM callback requesting app shutdown; forwarded unless already stopping. */
+  @Override
+  public void onShutdownRequest() {
+    if (stopRequested) {
+      return;
+    }
+    getContext().appShutdownRequested();
+  }
+
+  /**
+   * RM client error callback; reported to the app as a resource manager error
+   * unless the scheduler is stopping.
+   *
+   * @param e the error raised by the async RM client
+   */
+  @Override
+  public void onError(Throwable e) {
+    // fix: log message typo ("ARMRMClient" -> "AMRMClient")
+    LOG.error("Error from AMRMClient", e);
+    if (!stopRequested) {
+      getContext().reportError(YarnTaskSchedulerServiceError.RESOURCEMANAGER_ERROR,
+          StringUtils.stringifyException(e), null);
+    }
+  }
+
+  /** @return the currently available cluster resources as reported by the RM client */
+  @Override
+  public Resource getAvailableResources() {
+    return client.getAvailableResources();
+  }
+
+  /** @return the total cluster resources captured on the first heartbeat (zero before then) */
+  @Override
+  public Resource getTotalResources() {
+    return totalResources;
+  }
+
+  /** @return the number of nodes in the cluster as reported by the RM client */
+  @Override
+  public int getClusterNodeCount() {
+    return client.getClusterNodeCount();
+  }
+
+  /**
+   * Blacklist a node so no further assignments land on it, and propagate the
+   * blacklist addition to the RM.
+   *
+   * @param nodeId the node to blacklist
+   */
+  @Override
+  public synchronized void blacklistNode(NodeId nodeId) {
+    LOG.info("Blacklisting node: {}", nodeId);
+    blacklistedNodes.add(nodeId);
+    client.updateBlacklist(Collections.singletonList(nodeId.getHost()), null);
+  }
+
+  /**
+   * Remove a node from the blacklist, propagating the removal to the RM only
+   * if the node was actually blacklisted.
+   *
+   * @param nodeId the node to unblacklist
+   */
+  @Override
+  public synchronized void unblacklistNode(NodeId nodeId) {
+    if (blacklistedNodes.remove(nodeId)) {
+      LOG.info("Removing blacklist for node: {}", nodeId);
+      client.updateBlacklist(null, Collections.singletonList(nodeId.getHost()));
+    }
+  }
+
+  /**
+   * Allocate a container for a task with host/rack locality preferences.
+   *
+   * @param task the task to run
+   * @param capability the resources the task needs
+   * @param hosts preferred hosts, may be null
+   * @param racks preferred racks, may be null
+   * @param priority the request priority
+   * @param containerSignature signature describing the container the task needs
+   * @param clientCookie opaque cookie returned to the app on allocation
+   */
+  @Override
+  public void allocateTask(Object task, Resource capability, String[] hosts, String[] racks,
+      Priority priority, Object containerSignature, Object clientCookie) {
+    final int vertexIndex = getContext().getVertexIndexForTask(task);
+    addTaskRequest(new TaskRequest(task, vertexIndex, capability, hosts, racks,
+        priority, containerSignature, clientCookie));
+  }
+
+  /**
+   * Allocate a container for a task with affinity to a specific container.
+   * If the affinity container is unknown or too small, the affinity is
+   * dropped; otherwise the container's host is used as a locality preference.
+   *
+   * @param task the task to run
+   * @param capability the resources the task needs
+   * @param containerId the container the task would like to run in
+   * @param priority the request priority
+   * @param containerSignature signature describing the container the task needs
+   * @param clientCookie opaque cookie returned to the app on allocation
+   */
+  @Override
+  public void allocateTask(Object task, Resource capability, ContainerId containerId,
+      Priority priority, Object containerSignature, Object clientCookie) {
+    String[] hosts = null;
+    synchronized (this) {
+      HeldContainer held = heldContainers.get(containerId);
+      if (held != null) {
+        if (held.canFit(capability)) {
+          hosts = new String[]{held.getHost()};
+        } else {
+          LOG.warn("Match request to container {} but {} does not fit in {}",
+              containerId, capability, held.getCapability());
+          containerId = null;
+        }
+      } else {
+        LOG.info("Ignoring match request to unknown container {}", containerId);
+        containerId = null;
+      }
+    }
+    // getVertexIndexForTask is called outside the lock to avoid lock nesting
+    int vertexIndex = getContext().getVertexIndexForTask(task);
+    TaskRequest request = new TaskRequest(task, vertexIndex, capability, hosts, null,
+        priority, containerSignature, clientCookie, containerId);
+    addTaskRequest(request);
+  }
+
+  /**
+   * Deallocate the container (if any) associated with a task.  If the task
+   * succeeded and container reuse is enabled, the container is made idle and
+   * an immediate reassignment is attempted; otherwise the container is
+   * released.  App callbacks are performed outside the scheduler lock.
+   *
+   * @param task the task to deallocate
+   * @param taskSucceeded true if the task completed successfully
+   * @param endReason why the task ended
+   * @param diagnostics diagnostic message, if any
+   * @return true if a container had been assigned or released for this task
+   */
+  @Override
+  public boolean deallocateTask(Object task, boolean taskSucceeded,
+      TaskAttemptEndReason endReason, String diagnostics) {
+    ContainerId releasedLaunchedContainer = null;
+    AMState appState = getContext().getAMState();
+    boolean isSession = getContext().isSession();
+    TaskRequest newAssignment = null;
+    HeldContainer hc;
+    synchronized (this) {
+      TaskRequest request = removeTaskRequest(task);
+      if (request != null) {
+        // task was still pending; nothing was ever assigned
+        LOG.debug("Deallocating task {} before it was allocated", task);
+        return false;
+      }
+
+      hc = removeTaskAssignment(task);
+      if (hc != null) {
+        if (taskSucceeded && shouldReuseContainers) {
+          idleTracker.add(hc);
+          newAssignment = tryAssignReuseContainer(hc, appState, isSession);
+          if (newAssignment == null && hc.isReleasedAndUsed()) {
+            releasedLaunchedContainer = hc.getId();
+          }
+        } else {
+          if (releaseContainer(hc)) {
+            releasedLaunchedContainer = hc.getId();
+          }
+        }
+      }
+    }
+
+    // perform app callback outside of locks
+    if (newAssignment != null) {
+      informAppAboutAssignment(newAssignment, hc.getContainer());
+      return true;
+    }
+    if (releasedLaunchedContainer != null) {
+      getContext().containerBeingReleased(releasedLaunchedContainer);
+      return true;
+    }
+    return hc != null;
+  }
+
+  /**
+   * Release a specific container, returning the task that had been assigned
+   * to it, if any.  The release callback to the app happens outside the lock.
+   *
+   * @param containerId the container to release
+   * @return the task that was assigned to the container, or null
+   */
+  @Override
+  public Object deallocateContainer(ContainerId containerId) {
+    Object task = null;
+    ContainerId releasedLaunchedContainer = null;
+    synchronized (this) {
+      HeldContainer hc = heldContainers.remove(containerId);
+      if (hc != null) {
+        task = hc.getAssignedTask();
+        if (task != null) {
+          LOG.info("Deallocated container {} from task {}", containerId, task);
+        }
+        if (releaseContainer(hc)) {
+          releasedLaunchedContainer = hc.getId();
+        }
+      } else {
+        LOG.info("Ignoring deallocation of unknown container {}", containerId);
+      }
+    }
+
+    // perform app callback outside of locks
+    if (releasedLaunchedContainer != null) {
+      getContext().containerBeingReleased(releasedLaunchedContainer);
+    }
+    return task;
+  }
+
+  /**
+   * Record the assignment of a container to a task request: the pending
+   * request is removed, the assignment is tracked, and the container is no
+   * longer considered idle.
+   *
+   * @param request the request being satisfied
+   * @param hc the container being assigned
+   * @param match the locality/affinity that matched (for logging)
+   */
+  @GuardedBy("this")
+  private void assignContainer(TaskRequest request, HeldContainer hc, Object match) {
+    LOG.info("Assigning container {} to task {} host={} priority={} capability={} match={} lastTask={}",
+        hc.getId(), request.getTask(), hc.getHost(), hc.getPriority(), hc.getCapability(), match, hc.getLastTask());
+    removeTaskRequest(request.getTask());
+    addTaskAssignment(request, hc);
+    idleTracker.remove(hc);
+  }
+
+  private synchronized boolean releaseContainer(HeldContainer hc) {
+    Object task = containerCompleted(hc);
+    client.releaseAssignedContainer(hc.getId());
+    if (task != null) {
+      releasedContainers.put(hc.getId(), task);
+      return true;
+    }
+    return false;
+  }
+
+  /**
+   * Track a new task-to-container assignment, updating the per-vertex
+   * assignment index used by the preemption logic.
+   *
+   * @param request the request being satisfied
+   * @param hc the container assigned to the task
+   */
+  @GuardedBy("this")
+  private void addTaskAssignment(TaskRequest request, HeldContainer hc) {
+    HeldContainer oldContainer = taskAssignments.put(request.getTask(), hc);
+    if (oldContainer != null) {
+      // should not happen; log loudly rather than silently losing a container
+      LOG.error("Task {} being assigned to container {} but was already assigned to container {}",
+          request.getTask(), hc.getId(), oldContainer.getId());
+    }
+    Integer vertexIndex = request.vertexIndex;
+    Set<HeldContainer> cset = vertexAssignments.get(vertexIndex);
+    if (cset == null) {
+      cset = new HashSet<>();
+      vertexAssignments.put(vertexIndex, cset);
+      assignedVertices.set(vertexIndex);
+    }
+    cset.add(hc);
+    hc.assignTask(request);
+  }
+
+  /**
+   * Remove the assignment tracking for a task, keeping the per-vertex
+   * assignment index consistent.
+   *
+   * @param task the task whose assignment should be removed
+   * @return the container the task was assigned to, or null if none
+   */
+  @GuardedBy("this")
+  private HeldContainer removeTaskAssignment(Object task) {
+    HeldContainer hc = taskAssignments.remove(task);
+    if (hc != null) {
+      TaskRequest request = hc.removeAssignment();
+      if (request != null) {
+        Integer vertexIndex = request.vertexIndex;
+        Set<HeldContainer> cset = vertexAssignments.get(vertexIndex);
+        // drop the vertex from the index once its last assignment is gone
+        if (cset != null && cset.remove(hc) && cset.isEmpty()) {
+          vertexAssignments.remove(vertexIndex);
+          assignedVertices.clear(vertexIndex);
+        }
+      } else {
+        LOG.error("Container {} had assigned task {} but no request?!?", hc.getId(), task);
+      }
+    }
+    return hc;
+  }
+
+  /**
+   * Clean up all tracking state for a container that is done (completed or
+   * being released): idle tracking, held-container map, allocated resource
+   * accounting, and any task assignment.
+   *
+   * @param hc the container that completed
+   * @return the last task that ran in the container, or null if none
+   */
+  @GuardedBy("this")
+  @Nullable
+  private Object containerCompleted(HeldContainer hc) {
+    idleTracker.remove(hc);
+    heldContainers.remove(hc.getId());
+    Resources.subtractFrom(allocatedResources, hc.getCapability());
+    removeTaskAssignment(hc.getAssignedTask());
+    hc.released();
+    return hc.getLastTask();
+  }
+
+  @GuardedBy("this")
+  private void ensureVertexDescendants() {
+    if (vertexDescendants == null) {
+      DagInfo info = getContext().getCurrentDagInfo();
+      if (info == null) {
+        throw new IllegalStateException("Scheduling tasks but no current DAG info?");
+      }
+      int numVertices = info.getTotalVertices();
+      ArrayList<BitSet> descendants = new ArrayList<>(numVertices);
+      for (int i = 0; i < numVertices; ++i) {
+        descendants.add(info.getVertexDescendants(i));
+      }
+      vertexDescendants = descendants;
+    }
+  }
+
+  /**
+   * Submit a new task request.  If container reuse is enabled, an attempt is
+   * made to satisfy it immediately from an idle container; otherwise (or on
+   * failure) the request is queued with the tracker and the RM client.  The
+   * allocation callback to the app happens outside the lock.
+   *
+   * @param request the new task request
+   */
+  private void addTaskRequest(TaskRequest request) {
+    Container assignedContainer = null;
+    synchronized (this) {
+      if (shouldReuseContainers && !stopRequested && getContext().getAMState() != AMState.COMPLETED) {
+        ensureVertexDescendants();
+        activateSessionContainers();
+        HeldContainer hc = tryAssignTaskToIdleContainer(request);
+        if (hc != null) {
+          assignedContainer = hc.getContainer();
+        }
+      }
+
+      if (assignedContainer == null) {
+        ensureVertexDescendants();
+        TaskRequest old = requestTracker.add(request);
+        if (old != null) {
+          // a request for the same task was pending; drop the stale one
+          removeTaskRequestByRequest(request);
+        }
+        client.addContainerRequest(request);
+
+        // register the affinity so the target container can find this request
+        HeldContainer hc = heldContainers.get(request.getAffinity());
+        if (hc != null) {
+          hc.addAffinity(request);
+        }
+      }
+    }
+
+    // perform app callback outside of locks
+    if (assignedContainer != null) {
+      informAppAboutAssignment(request, assignedContainer);
+    }
+  }
+
+  @Nullable
+  private synchronized TaskRequest removeTaskRequest(Object task) {
+    TaskRequest request = requestTracker.remove(task);
+    if (request != null) {
+      removeTaskRequestByRequest(request);
+    }
+    return request;
+  }
+
+  /**
+   * Remove a request from the RM client and unregister any container
+   * affinity it had declared.
+   *
+   * @param request the request to remove
+   */
+  @GuardedBy("this")
+  private void removeTaskRequestByRequest(TaskRequest request) {
+    client.removeContainerRequest(request);
+    HeldContainer hc = heldContainers.get(request.getAffinity());
+    if (hc != null) {
+      hc.removeAffinity(request);
+    }
+  }
+
+  /**
+   * Try to satisfy a new request immediately from an idle container.  The
+   * request's container affinity is tried first, then idle containers are
+   * matched from best locality to worst (node, rack, any), honoring each
+   * container's current matching level.
+   *
+   * @param request the new task request
+   * @return the container assigned, or null if no idle container matched
+   */
+  @GuardedBy("this")
+  @Nullable
+  private HeldContainer tryAssignTaskToIdleContainer(TaskRequest request) {
+    if (requestTracker.isRequestBlocked(request)) {
+      LOG.debug("Cannot assign task {} to an idle container since vertex {} is a descendant of pending tasks",
+          request.getTask(), request.getVertexIndex());
+      return null;
+    }
+
+    // check if container affinity can be satisfied immediately
+    ContainerId affinity = request.getAffinity();
+    if (affinity != null) {
+      HeldContainer hc = heldContainers.get(affinity);
+      if (hc != null && hc.isAssignable()) {
+        assignContainer(request, hc, affinity);
+        return hc;
+      }
+    }
+
+    // try to match the task against idle containers in order from best locality to worst
+    HeldContainer hc;
+    if (request.hasLocality()) {
+      hc = tryAssignTaskToIdleContainer(request, request.getNodes(), HeldContainerState.MATCHES_LOCAL_STATES);
+      if (hc == null) {
+        hc = tryAssignTaskToIdleContainer(request, request.getRacks(), HeldContainerState.MATCHES_RACK_STATES);
+        if (hc == null) {
+          hc = tryAssignTaskToIdleContainer(request, ResourceRequest.ANY, HeldContainerState.MATCHES_ANY_STATES);
+        }
+      }
+    } else {
+      // a request without locality can match a container at any level
+      hc = tryAssignTaskToIdleContainer(request, ResourceRequest.ANY, HeldContainerState.MATCHES_LOCAL_STATES);
+    }
+
+    return hc;
+  }
+
+  /**
+   * Try each of the given locations in order, returning the first idle
+   * container that can take the request.
+   *
+   * @param request the new task request
+   * @param locations candidate locations (hosts or racks), may be null/empty
+   * @param eligibleStates container matching states allowed for this locality
+   * @return the container assigned, or null if none matched
+   */
+  @GuardedBy("this")
+  @Nullable
+  private HeldContainer tryAssignTaskToIdleContainer(TaskRequest request,
+      List<String> locations, EnumSet<HeldContainerState> eligibleStates) {
+    if (locations == null || locations.isEmpty()) {
+      return null;
+    }
+    for (String location : locations) {
+      HeldContainer assigned = tryAssignTaskToIdleContainer(request, location, eligibleStates);
+      if (assigned != null) {
+        return assigned;
+      }
+    }
+    return null;
+  }
+
+  /**
+   * Try to assign the request to an idle container at a single location,
+   * preferring containers with the fewest affinity requests so containers
+   * that other tasks specifically asked for are kept available.
+   *
+   * @param request the new task request
+   * @param location the location to search (host, rack, or ANY)
+   * @param eligibleStates container matching states allowed for this locality
+   * @return the container assigned, or null if none matched
+   */
+  @GuardedBy("this")
+  @Nullable
+  private HeldContainer tryAssignTaskToIdleContainer(TaskRequest request,
+      String location, EnumSet<HeldContainerState> eligibleStates) {
+    Set<HeldContainer> containers = idleTracker.getByLocation(location);
+    HeldContainer bestMatch = null;
+    if (containers != null && !containers.isEmpty()) {
+      for (HeldContainer hc : containers) {
+        if (eligibleStates.contains(hc.getState())) {
+          // a null signature means the container has not launched yet
+          Object csig = hc.getSignature();
+          if (csig == null || signatureMatcher.isSuperSet(csig, request.getContainerSignature())) {
+            int numAffinities = hc.getNumAffinities();
+            if (numAffinities == 0) {
+              // nothing else wants this container specifically; take it
+              bestMatch = hc;
+              break;
+            }
+            if (bestMatch == null || numAffinities < bestMatch.getNumAffinities()) {
+              bestMatch = hc;
+            }
+          } else {
+            LOG.debug("Unable to assign task {} to container {} due to signature mismatch", request.getTask(), hc.getId());
+          }
+        }
+      }
+    }
+    if (bestMatch != null) {
+      assignContainer(request, bestMatch, location);
+    }
+    return bestMatch;
+  }
+
+  /** Mark that the scheduler should unregister from the RM on shutdown. */
+  @Override
+  public void setShouldUnregister() {
+    shouldUnregister = true;
+  }
+
+  /** @return true if the scheduler has unregistered from the RM */
+  @Override
+  public boolean hasUnregistered() {
+    return hasUnregistered;
+  }
+
+  /**
+   * Reset per-DAG state when a DAG finishes: session containers go back to
+   * local matching and the cached vertex descendant bitsets are discarded so
+   * they are rebuilt for the next DAG.
+   */
+  @Override
+  public synchronized void dagComplete() {
+    for (HeldContainer hc : sessionContainers) {
+      hc.resetMatchingLevel();
+    }
+    vertexDescendants = null;
+  }
+
+  /**
+   * Decide whether any containers should be preempted to satisfy the highest
+   * priority pending requests.  Preemption is skipped when disabled, when the
+   * minimum heartbeat interval since the last preemption has not elapsed,
+   * when the top request fits in free resources, when idle containers exist,
+   * or when no running task is a descendant of the top pending requests.
+   *
+   * @param freeResources currently free cluster resources
+   * @return the containers to preempt, or null if no preemption is needed
+   */
+  @GuardedBy("this")
+  @Nullable
+  private Collection<ContainerId> maybePreempt(Resource freeResources) {
+    if (preemptionPercentage == 0 || numHeartbeats - lastPreemptionHeartbeat < numHeartbeatsBetweenPreemptions) {
+      return null;
+    }
+    if (!requestTracker.isPreemptionDeadlineExpired() && requestTracker.fitsHighestPriorityRequest(freeResources)) {
+      // throttle this log to once per 50 heartbeats
+      if (numHeartbeats % 50 == 1) {
+        LOG.info("Highest priority request fits in free resources {}", freeResources);
+      }
+      return null;
+    }
+
+    int numIdleContainers = idleTracker.getNumContainers();
+    if (numIdleContainers > 0) {
+      if (numHeartbeats % 50 == 1) {
+        LOG.info("Avoiding preemption since there are {} idle containers", numIdleContainers);
+      }
+      return null;
+    }
+
+    BitSet blocked = requestTracker.createVertexBlockedSet();
+    if (!blocked.intersects(assignedVertices)) {
+      if (numHeartbeats % 50 == 1) {
+        LOG.info("Avoiding preemption since there are no descendants of the highest priority requests running");
+      }
+      return null;
+    }
+
+    Resource preemptLeft = requestTracker.getAmountToPreempt(preemptionPercentage);
+    if (!resourceCalculator.anyAvailable(preemptLeft)) {
+      if (numHeartbeats % 50 == 1) {
+        LOG.info("Avoiding preemption since amount to preempt is {}", preemptLeft);
+      }
+      return null;
+    }
+
+    // collect containers assigned to blocked descendant vertices, ordered by
+    // the preemption comparator so the least valuable work is preempted first
+    PriorityQueue<HeldContainer> candidates = new PriorityQueue<>(11, PREEMPT_ORDER_COMPARATOR);
+    blocked.and(assignedVertices);
+    for (int i = blocked.nextSetBit(0); i >= 0; i = blocked.nextSetBit(i + 1)) {
+      Collection<HeldContainer> containers = vertexAssignments.get(i);
+      if (containers != null) {
+        candidates.addAll(containers);
+      } else {
+        LOG.error("Vertex {} in assignedVertices but no assignments?", i);
+      }
+    }
+
+    // preempt candidates until the target amount of resources is covered
+    ArrayList<ContainerId> preemptedContainers = new ArrayList<>();
+    HeldContainer hc;
+    while ((hc = candidates.poll()) != null) {
+      LOG.info("Preempting container {} currently allocated to task {}", hc.getId(), hc.getAssignedTask());
+      preemptedContainers.add(hc.getId());
+      resourceCalculator.deductFrom(preemptLeft, hc.getCapability());
+      if (!resourceCalculator.anyAvailable(preemptLeft)) {
+        break;
+      }
+    }
+
+    return preemptedContainers;
+  }
+
+  /**
+   * Build the periodic status line logged on heartbeats: allocation totals,
+   * pending request counts, and details of the highest waiting request when
+   * one exists.
+   *
+   * @param freeResource currently free cluster resources
+   * @return the formatted status line
+   */
+  @GuardedBy("this")
+  private String constructPeriodicLog(Resource freeResource) {
+    Priority topPriority = requestTracker.getHighestPriority();
+    StringBuilder sb = new StringBuilder("Allocated: ").append(allocatedResources)
+        .append(" Free: ").append(freeResource)
+        .append(" pendingRequests: ").append(requestTracker.getNumRequests())
+        .append(" heldContainers: ").append(heldContainers.size())
+        .append(" heartbeats: ").append(numHeartbeats)
+        .append(" lastPreemptionHeartbeat: ").append(lastPreemptionHeartbeat);
+    if (topPriority != null) {
+      sb.append(" highestWaitingRequestWaitStartTime: ")
+          .append(requestTracker.getHighestPriorityWaitTimestamp())
+          .append(" highestWaitingRequestPriority: ").append(topPriority);
+    }
+    return sb.toString();
+  }
+
+  /** @return the number of currently blacklisted nodes (test hook) */
+  @VisibleForTesting
+  int getNumBlacklistedNodes() {
+    return blacklistedNodes.size();
+  }
+
+  /** @return the currently retained session containers (test hook) */
+  @VisibleForTesting
+  Collection<HeldContainer> getSessionContainers() {
+    return sessionContainers;
+  }
+
+  // Wrapper class to work around lack of blacklisting APIs in async client.
+  // This can be removed once Tez requires YARN >= 2.7.0
+  // Wrapper class to work around lack of blacklisting APIs in async client.
+  // This can be removed once Tez requires YARN >= 2.7.0
+  static class AMRMClientAsyncWrapper extends AMRMClientAsyncImpl<TaskRequest> {
+    AMRMClientAsyncWrapper(AMRMClient<TaskRequest> syncClient, int intervalMs, CallbackHandler handler) {
+      super(syncClient, intervalMs, handler);
+    }
+
+    /** Forward a blacklist update to the underlying synchronous client. */
+    public void updateBlacklist(List<String> additions, List<String> removals) {
+      client.updateBlacklist(additions, removals);
+    }
+  }
+
+  /**
+   * A utility class to track a task allocation.
+   */
+  /**
+   * A utility class to track a task allocation.
+   */
+  static class TaskRequest extends AMRMClient.ContainerRequest {
+    // the app's task object this request is for
+    final Object task;
+    // index of the vertex containing the task, used for DAG-aware scheduling
+    final int vertexIndex;
+    // opaque container signature used for reuse compatibility checks
+    final Object signature;
+    // opaque app cookie returned on allocation
+    final Object cookie;
+    // container this task would like to run in, or null for no affinity
+    final ContainerId affinityContainerId;
+
+    TaskRequest(Object task, int vertexIndex, Resource capability, String[] hosts, String[] racks,
+        Priority priority, Object signature, Object cookie) {
+      this(task, vertexIndex, capability, hosts, racks, priority, signature, cookie,  null);
+    }
+
+    TaskRequest(Object task, int vertexIndex, Resource capability, String[] hosts, String[] racks,
+        Priority priority, Object signature, Object cookie, ContainerId affinityContainerId) {
+      super(capability, hosts, racks, priority);
+      this.task = task;
+      this.vertexIndex = vertexIndex;
+      this.signature = signature;
+      this.cookie = cookie;
+      this.affinityContainerId = affinityContainerId;
+    }
+
+    Object getTask() {
+      return task;
+    }
+
+    int getVertexIndex() {
+      return vertexIndex;
+    }
+
+    Object getContainerSignature() {
+      return signature;
+    }
+
+    Object getCookie() {
+      return cookie;
+    }
+
+    /** @return the affinity container for this request, or null if none */
+    @Nullable
+    ContainerId getAffinity() {
+      return affinityContainerId;
+    }
+
+    /** @return true if the request specifies any host or rack preference */
+    boolean hasLocality() {
+      List<String> nodes = getNodes();
+      List<String> racks = getRacks();
+      return (nodes != null && !nodes.isEmpty()) || (racks != null && !racks.isEmpty());
+    }
+  }
+
+  /**
+   * Lifecycle state of a held container.  The MATCHING_* states describe the
+   * best locality level at which the idle container may currently be matched
+   * to requests; the level loosens over time (local -> rack -> any).
+   */
+  private enum HeldContainerState {
+    MATCHING_LOCAL(true),
+    MATCHING_RACK(true),
+    MATCHING_ANY(true),
+    ASSIGNED(false),
+    RELEASED(false);
+
+    // states eligible when matching at node-local, rack, or any level
+    private static final EnumSet<HeldContainerState> MATCHES_LOCAL_STATES = EnumSet.of(
+        HeldContainerState.MATCHING_LOCAL, HeldContainerState.MATCHING_RACK, HeldContainerState.MATCHING_ANY);
+    private static final EnumSet<HeldContainerState> MATCHES_RACK_STATES = EnumSet.of(
+        HeldContainerState.MATCHING_RACK, HeldContainerState.MATCHING_ANY);
+    private static final EnumSet<HeldContainerState> MATCHES_ANY_STATES = EnumSet.of(HeldContainerState.MATCHING_ANY);
+
+    // whether a container in this state may be assigned a new task
+    private final boolean assignable;
+
+    HeldContainerState(boolean assignable) {
+      this.assignable = assignable;
+    }
+
+    boolean isAssignable() {
+      return assignable;
+    }
+  }
+
+  /**
+   * Tracking for an allocated container.
+   */
+  @VisibleForTesting
+  class HeldContainer implements Callable<Void> {
+    final Container container;
+    final String rack;
+    @GuardedBy("DagAwareYarnTaskScheduler.this")
+    HeldContainerState state = HeldContainerState.MATCHING_LOCAL;
+
+    /** The Future received when scheduling an idle container for re-allocation at a later time. */
+    @GuardedBy("DagAwareYarnTaskScheduler.this")
+    Future<Void> future = null;
+
+    /** The collection of task requests that have specified this container as a scheduling affinity. */
+    @GuardedBy("DagAwareYarnTaskScheduler.this")
+    Collection<TaskRequest> affinities = null;
+
+    /**
+     * The task request corresponding to the currently assigned task to this container.
+     * This field is null when the container is not currently assigned.
+     */
+    @GuardedBy("DagAwareYarnTaskScheduler.this")
+    TaskRequest assignedRequest = null;
+
+    /** The task request corresponding to the last task that was assigned to this container. */
+    @GuardedBy("DagAwareYarnTaskScheduler.this")
+    TaskRequest lastRequest = null;
+
+    /** The timestamp when the idle container will expire. 0 if the container is not idle. */
+    @GuardedBy("DagAwareYarnTaskScheduler.this")
+    long idleExpirationTimestamp = 0;
+
+    /** The timestamp when this container was assigned. 0 if the container is not assigned. */
+    @GuardedBy("DagAwareYarnTaskScheduler.this")
+    long assignmentTimestamp = 0;
+
+    HeldContainer(Container container) {
+      this.container = container;
+      this.rack = RackResolver.resolve(container.getNodeId().getHost()).getNetworkLocation();
+    }
+
+    /** Returns the current lifecycle/matching state of this container. */
+    HeldContainerState getState() {
+      return state;
+    }
+
+    /** True if the container is in a state where a new task may be matched to it. */
+    boolean isAssignable() {
+      return state.isAssignable();
+    }
+
+    /**
+     * True if this container has been released but ran at least one task,
+     * i.e. the app should be told the container is going away.
+     */
+    boolean isReleasedAndUsed() {
+      return state == HeldContainerState.RELEASED && getLastTask() != null;
+    }
+
+    Container getContainer() {
+      return container;
+    }
+
+    ContainerId getId() {
+      return container.getId();
+    }
+
+    String getHost() {
+      return container.getNodeId().getHost();
+    }
+
+    /** Returns the rack resolved for this container's host at construction time. */
+    String getRack() {
+      return rack;
+    }
+
+    Priority getPriority() {
+      return container.getPriority();
+    }
+
+    Resource getCapability() {
+      return container.getResource();
+    }
+
+    /** Returns the task currently running in this container, or null if unassigned. */
+    @Nullable
+    Object getAssignedTask() {
+      return assignedRequest != null ? assignedRequest.getTask() : null;
+    }
+
+    /**
+     * Records the given request as this container's running task, marks the
+     * container ASSIGNED, and cancels any pending reuse scheduling.
+     * Logs an error (without failing) if a task was already assigned.
+     */
+    void assignTask(TaskRequest request) {
+      assert state != HeldContainerState.ASSIGNED && state != HeldContainerState.RELEASED;
+      TaskRequest current = assignedRequest;
+      if (current != null) {
+        LOG.error("Container {} assigned task {} but already running task {}",
+            getId(), request.getTask(), current.getTask());
+      }
+      lastRequest = request;
+      assignedRequest = request;
+      state = HeldContainerState.ASSIGNED;
+      assignmentTimestamp = now();
+      idleExpirationTimestamp = 0;
+      Future<Void> pending = future;
+      if (pending != null) {
+        pending.cancel(false);
+        future = null;
+      }
+    }
+
+    TaskRequest removeAssignment() {
+      assert state == HeldContainerState.ASSIGNED;
+      TaskRequest result = assignedRequest;
+      assignedRequest = null;
+      assignmentTimestamp = 0;
+      state = HeldContainerState.MATCHING_LOCAL;
+      return result;
+    }
+
+    /** Registers a task request that prefers to run in this specific container. */
+    void addAffinity(TaskRequest request) {
+      if (affinities == null) {
+        affinities = new HashSet<>();
+      }
+      affinities.add(request);
+    }
+
+    /** Removes an affinity registration, dropping the set entirely once empty. */
+    void removeAffinity(TaskRequest request) {
+      if (affinities != null && affinities.remove(request) && affinities.isEmpty()) {
+        affinities = null;
+      }
+    }
+
+    int getNumAffinities() {
+      return affinities != null ? affinities.size() : 0;
+    }
+
+    /** Returns the requests with affinity to this container, or null if there are none. */
+    @Nullable
+    Collection<TaskRequest> getAffinities() {
+      return affinities;
+    }
+
+    /**
+     * Schedules this container's reuse callback ({@link #call()}) to run after the
+     * given delay, replacing any previously scheduled callback. A rejected
+     * submission is only an error if the scheduler has not been asked to stop.
+     */
+    void scheduleForReuse(long delayMillis) {
+      assert state != HeldContainerState.ASSIGNED && state != HeldContainerState.RELEASED;
+      try {
+        if (future != null) {
+          future.cancel(false);
+        }
+        future = reuseExecutor.schedule(this, delayMillis, TimeUnit.MILLISECONDS);
+      } catch (RejectedExecutionException e) {
+        if (!stopRequested) {
+          LOG.error("Container {} could not be scheduled for reuse!", getId(), e);
+        }
+      }
+    }
+
+    /** Returns the container signature of the last assigned request, or null if never assigned. */
+    @Nullable
+    Object getSignature() {
+      return lastRequest != null ? lastRequest.getContainerSignature() : null;
+    }
+
+    /** Returns the last task that ran in this container, or null if never assigned. */
+    @Nullable
+    Object getLastTask() {
+      return lastRequest != null ? lastRequest.getTask() : null;
+    }
+
+    /**
+     * Returns the location string to use when matching requests against this
+     * container at its current matching level: host, rack, or ANY.
+     * @throws IllegalStateException if the container is not in a matching state
+     */
+    String getMatchingLocation() {
+      switch (state) {
+      case MATCHING_LOCAL:
+        return getHost();
+      case MATCHING_RACK:
+        return getRack();
+      case MATCHING_ANY:
+        return ResourceRequest.ANY;
+      default:
+        throw new IllegalStateException("Container " + getId() + " trying to match in state " + state);
+      }
+    }
+
+    /**
+     * Widens this container's matching scope by one level (local to rack, rack
+     * to any), but only when the corresponding reuse setting permits it.
+     * @throws IllegalStateException if the container is not in a matching state
+     */
+    void moveToNextMatchingLevel() {
+      if (state == HeldContainerState.MATCHING_LOCAL) {
+        if (reuseRackLocal) {
+          state = HeldContainerState.MATCHING_RACK;
+        }
+      } else if (state == HeldContainerState.MATCHING_RACK) {
+        if (reuseNonLocal) {
+          state = HeldContainerState.MATCHING_ANY;
+        }
+      } else if (state != HeldContainerState.MATCHING_ANY) {
+        throw new IllegalStateException("Container " + getId() + " trying to match in state " + state);
+      }
+    }
+
+    /**
+     * Returns true when the container cannot widen its matching scope any
+     * further given the configured reuse settings.
+     * @throws IllegalStateException if the container is not in a matching state
+     */
+    boolean atMaxMatchLevel() {
+      if (state == HeldContainerState.MATCHING_LOCAL) {
+        return !reuseRackLocal;
+      }
+      if (state == HeldContainerState.MATCHING_RACK) {
+        return !reuseNonLocal;
+      }
+      if (state == HeldContainerState.MATCHING_ANY) {
+        return true;
+      }
+      throw new IllegalStateException("Container " + getId() + " trying to match in state " + state);
+    }
+
+    /** Narrows matching back to local level, but only if the container is still assignable. */
+    void resetMatchingLevel() {
+      if (isAssignable()) {
+        state = HeldContainerState.MATCHING_LOCAL;
+      }
+    }
+
+    /**
+     * Lazily computes and returns the time at which this idle container expires.
+     * The expiration is randomized between the configured min and max idle
+     * timeouts to avoid releasing many containers simultaneously; a
+     * non-positive min timeout means the container never expires.
+     */
+    long getIdleExpirationTimestamp(long now) {
+      if (idleExpirationTimestamp == 0) {
+        if (idleContainerTimeoutMin > 0) {
+          idleExpirationTimestamp = now + random.nextLong(idleContainerTimeoutMin, idleContainerTimeoutMax);
+        } else {
+          idleExpirationTimestamp = Long.MAX_VALUE;
+        }
+      }
+      return idleExpirationTimestamp;
+    }
+
+    long getAssignmentTimestamp() {
+      return assignmentTimestamp;
+    }
+
+    /** True if this container's memory and vcores are both at least the requested capability. */
+    boolean canFit(Resource capability) {
+      Resource cr = container.getResource();
+      return cr.getMemory() >= capability.getMemory() && cr.getVirtualCores() >= capability.getVirtualCores();
+    }
+
+    /**
+     * Reuse callback executed by the reuse executor when this idle container's
+     * delay fires. Under the scheduler lock it widens the matching level and
+     * tries to assign a pending request; if nothing matched and the container
+     * was released after use, it is flagged for release. App notifications are
+     * made after the synchronized block, outside the scheduler lock.
+     */
+    @Override
+    public Void call() throws Exception {
+      AMState appState = getContext().getAMState();
+      boolean isSession = getContext().isSession();
+      TaskRequest assigned = null;
+      ContainerId released = null;
+      synchronized (DagAwareYarnTaskScheduler.this) {
+        // clear the handle first: this callback is the scheduled future
+        future = null;
+        if (isAssignable()) {
+          moveToNextMatchingLevel();
+          assigned = tryAssignReuseContainer(this, appState, isSession);
+          if (assigned == null && isReleasedAndUsed()) {
+            released = getId();
+          }
+        }
+      }
+      if (assigned != null) {
+        informAppAboutAssignment(assigned, container);
+      }
+      if (released != null) {
+        getContext().containerBeingReleased(released);
+      }
+      return null;
+    }
+
+    /** Moves the container to the terminal RELEASED state and cancels any pending reuse callback. */
+    void released() {
+      assert state != HeldContainerState.RELEASED;
+      state = HeldContainerState.RELEASED;
+      if (future != null) {
+        future.cancel(false);
+      }
+      future = null;
+    }
+  }
+
+  /**
+   * Orders containers for preemption from most recently assigned to least
+   * recently assigned. Unassigned containers (timestamp 0) are treated as
+   * most recent so they sort first.
+   */
+  private static class PreemptOrderComparator implements Comparator<HeldContainer> {
+    @Override
+    public int compare(HeldContainer o1, HeldContainer o2) {
+      long first = o1.getAssignmentTimestamp() == 0
+          ? Long.MAX_VALUE : o1.getAssignmentTimestamp();
+      long second = o2.getAssignmentTimestamp() == 0
+          ? Long.MAX_VALUE : o2.getAssignmentTimestamp();
+      // descending order: larger (newer) timestamps compare as smaller
+      return Long.compare(second, first);
+    }
+  }
+
+  /**
+   * Utility class for a request, container pair
+   */
+  private static class Assignment {
+    /** The task request that was satisfied. */
+    final TaskRequest request;
+    /** The container the request was assigned to. */
+    final Container container;
+
+    Assignment(TaskRequest request, Container container) {
+      this.request = request;
+      this.container = container;
+    }
+  }
+
+  /**
+   * Utility class for a task, container exit status pair
+   */
+  private static class TaskStatus {
+    /** The app-level task object whose container completed. */
+    final Object task;
+    /** The completion status reported by YARN for the task's container. */
+    final ContainerStatus status;
+
+    TaskStatus(Object task, ContainerStatus status) {
+      this.task = task;
+      this.status = status;
+    }
+  }
+
+  /**
+   * The task allocation request tracker tracks task allocations
+   * and keeps statistics on which priorities have requests and which vertices
+   * should be blocked from container reuse due to DAG topology.
+   */
+  private class RequestTracker {
+    /** Map from task object to its currently outstanding request. */
+    private final Map<Object, TaskRequest> requests = new HashMap<>();
+    /** request map ordered by priority with highest priority first */
+    private final NavigableMap<Priority, RequestPriorityStats> priorityStats =
+        new TreeMap<>(Collections.reverseOrder());
+    /** The highest priority with outstanding requests, null when there are none. */
+    private Priority highestPriority = null;
+    /** Time when the current highest priority became highest, 0 when there are no requests. */
+    private long highestPriorityWaitTimestamp = 0;
+
+    /**
+     * Records a task request, updating the per-priority statistics.
+     * @return the previous request for the same task, or null if there was none
+     */
+    @GuardedBy("DagAwareYarnTaskScheduler.this")
+    @Nullable
+    TaskRequest add(TaskRequest request) {
+      TaskRequest oldRequest = requests.put(request.getTask(), request);
+      Priority priority = request.getPriority();
+      RequestPriorityStats stats = priorityStats.get(priority);
+      if (stats == null) {
+        stats = addStatsForPriority(priority);
+      }
+      ++stats.requestCount;
+      if (request.hasLocality()) {
+        ++stats.localityCount;
+      }
+      incrVertexTaskCount(priority, stats, request.getVertexIndex());
+
+      if (oldRequest != null) {
+        updateStatsForRemoval(oldRequest);
+      }
+      return oldRequest;
+    }
+
+    /**
+     * Removes the outstanding request for the given task, if any.
+     * @return the removed request, or null if the task had no request
+     */
+    @GuardedBy("DagAwareYarnTaskScheduler.this")
+    @Nullable
+    TaskRequest remove(Object task) {
+      TaskRequest request = requests.remove(task);
+      if (request != null) {
+        updateStatsForRemoval(request);
+        return request;
+      }
+      return null;
+    }
+
+    /** Creates the stats entry for a priority seen for the first time. */
+    private RequestPriorityStats addStatsForPriority(Priority priority) {
+      BitSet allowedVerts = new BitSet(vertexDescendants.size());
+      Entry<Priority,RequestPriorityStats> lowerEntry = priorityStats.lowerEntry(priority);
+      if (lowerEntry != null) {
+        // initialize the allowed vertices BitSet using the information derived
+        // from the next higher priority entry
+        RequestPriorityStats priorStats = lowerEntry.getValue();
+        allowedVerts.or(priorStats.allowedVertices);
+        allowedVerts.andNot(priorStats.descendants);
+      } else {
+        // no higher priority entry so this priority is currently the highest
+        highestPriority = priority;
+        highestPriorityWaitTimestamp = now();
+        allowedVerts.set(0, vertexDescendants.size());
+      }
+      RequestPriorityStats stats = new RequestPriorityStats(vertexDescendants.size(), allowedVerts);
+      priorityStats.put(priority, stats);
+      return stats;
+    }
+
+    /** Rolls the given request out of the priority stats, dropping empty priorities. */
+    private void updateStatsForRemoval(TaskRequest request) {
+      Priority priority = request.getPriority();
+      RequestPriorityStats stats = priorityStats.get(priority);
+      decrVertexTaskCount(priority, stats, request.getVertexIndex());
+      --stats.requestCount;
+      if (request.hasLocality()) {
+        --stats.localityCount;
+      }
+      if (stats.requestCount == 0) {
+        priorityStats.remove(priority);
+        if (highestPriority.equals(priority)) {
+          if (priorityStats.isEmpty()) {
+            highestPriority = null;
+            highestPriorityWaitTimestamp = 0;
+          } else {
+            // a new priority is now the highest; restart its wait clock
+            highestPriority = priorityStats.firstKey();
+            highestPriorityWaitTimestamp = now();
+          }
+        }
+      }
+    }
+
+    @GuardedBy("DagAwareYarnTaskScheduler.this")
+    boolean isEmpty() {
+      return requests.isEmpty();
+    }
+
+    @GuardedBy("DagAwareYarnTaskScheduler.this")
+    int getNumRequests() {
+      return requests.size();
+    }
+
+    @GuardedBy("DagAwareYarnTaskScheduler.this")
+    List<Object> getTasks() {
+      return new ArrayList<>(requests.keySet());
+    }
+
+    @GuardedBy("DagAwareYarnTaskScheduler.this")
+    Collection<Entry<Priority, RequestPriorityStats>> getStatsEntries() {
+      return priorityStats.entrySet();
+    }
+
+    @GuardedBy("DagAwareYarnTaskScheduler.this")
+    @Nullable
+    Priority getHighestPriority() {
+      if (priorityStats.isEmpty()) {
+        return null;
+      }
+      return priorityStats.firstKey();
+    }
+
+    @GuardedBy("DagAwareYarnTaskScheduler.this")
+    long getHighestPriorityWaitTimestamp() {
+      return highestPriorityWaitTimestamp;
+    }
+
+    /**
+     * Returns true if the request must not be satisfied yet, either because a
+     * higher priority ancestor vertex still has outstanding requests
+     * (allowedVertices bit cleared) or because a vertex at the same or higher
+     * priority lists it as a descendant.
+     */
+    @GuardedBy("DagAwareYarnTaskScheduler.this")
+    boolean isRequestBlocked(TaskRequest request) {
+      Entry<Priority, RequestPriorityStats> entry = priorityStats.floorEntry(request.getPriority());
+      if (entry != null) {
+        RequestPriorityStats stats = entry.getValue();
+        int vertexIndex = request.getVertexIndex();
+        return !stats.allowedVertices.get(vertexIndex) || stats.descendants.get(vertexIndex);
+      }
+      return false;
+    }
+
+    private void incrVertexTaskCount(Priority priority, RequestPriorityStats stats, int vertexIndex) {
+      Integer vertexIndexInt = vertexIndex;
+      MutableInt taskCount = stats.vertexTaskCount.get(vertexIndexInt);
+      if (taskCount != null) {
+        taskCount.increment();
+      } else {
+        addVertexToRequestStats(priority, stats, vertexIndexInt);
+      }
+    }
+
+    private void decrVertexTaskCount(Priority priority, RequestPriorityStats stats, int vertexIndex) {
+      Integer vertexIndexInt = vertexIndex;
+      MutableInt taskCount = stats.vertexTaskCount.get(vertexIndexInt);
+      taskCount.decrement();
+      if (taskCount.intValue() <= 0) {
+        removeVertexFromRequestStats(priority, stats, vertexIndexInt);
+      }
+    }
+
+    /**
+     * Add a new vertex to a RequestPriorityStats.
+     *
+     * Adding a vertex to the request stats requires updating the stats descendants bitmask to include the descendants
+     * of the new vertex and also updating the allowedVertices bitmask for all lower priority requests to prevent any
+     * task request from a descendant vertex in the DAG from being allocated. This avoids assigning allocations to
+     * lower priority requests when a higher priority request of an ancestor is still pending, but it allows lower
+     * priority requests to be satisfied if higher priority requests are not ancestors. This is particularly useful
+     * for DAGs that have independent trees of vertices or significant, parallel branches within a tree.
+     *
+     * Requests are blocked by taking the specified vertex's full descendant vertex bitmask in vertexDescendants and
+     * clearing those bits for all lower priority requests. For the following example DAG where each vertex index
+     * corresponds to its letter position (i.e.: A=0, B=1, C=2, etc.)
+     *
+     *       A
+     *       |
+     *   C---B----E
+     *   |        |
+     *   D        F
+     *            |
+     *          G---H
+     *
+     * Vertices F, G, and H are descendants of E but all other vertices are not. The vertexDescendants bitmask for
+     * vertex E is therefore 11100000b or 0xE0. When the first vertex E task request arrives we need to disallow
+     * requests for all descendants of E. That is accomplished by iterating through the request stats for all lower
+     * priority requests and clearing the allowedVertex bits corresponding to the descendants,
+     * i.e: allowedVertices = allowedVertices & ~descendants
+     */
+    private void addVertexToRequestStats(Priority priority, RequestPriorityStats stats, Integer vertexIndexInt) {
+      // Creating a new vertex entry for this priority, so the allowed vertices for all
+      // lower priorities need to be updated based on the descendants of the new vertex.
+      stats.vertexTaskCount.put(vertexIndexInt, new MutableInt(1));
+      int vertexIndex = vertexIndexInt;
+      stats.vertices.set(vertexIndex);
+      BitSet d = vertexDescendants.get(vertexIndex);
+      stats.descendants.or(d);
+      for (RequestPriorityStats lowerStat : priorityStats.tailMap(priority, false).values()) {
+        lowerStat.allowedVertices.andNot(d);
+      }
+    }
+
+    /**
+     * Removes a vertex from a RequestPriorityStats.
+     *
+     * Removing a vertex is more expensive than adding a vertex. The stats contain bitmasks which only store on/off
+     * values rather than reference counts. Therefore we must rebuild the descendants bitmasks from the remaining
+     * vertices in the request stats. Once the new descendants mask is computed we then need to rebuild the
+     * allowedVertices BitSet for all lower priority request stats in case the removal of this vertex unblocks lower
+     * priority requests of a descendant vertex.
+     *
+     * Rebuilding allowedVertices for the lower priorities involves starting with the allowedVertices mask at the
+     * current priority then masking off the descendants at each priority level encountered, accumulating the results.
+     * Any descendants of a level will be blocked at all lower levels. See the addVertexToRequestStats documentation
+     * for details on how vertices map to the descendants and allowedVertices bit masks.
+     */
+    private void removeVertexFromRequestStats(Priority priority, RequestPriorityStats stats, Integer vertexIndexInt) {
+      stats.vertexTaskCount.remove(vertexIndexInt);
+      int vertexIndex = vertexIndexInt;
+      stats.vertices.clear(vertexIndex);
+
+      // Rebuild the descendants BitSet for the remaining vertices at this priority.
+      stats.descendants.clear();
+      for (Integer vIndex : stats.vertexTaskCount.keySet()) {
+        stats.descendants.or(vertexDescendants.get(vIndex));
+      }
+
+      // The allowedVertices for all lower priorities need to be recalculated where the vertex descendants at each
+      // level are removed from the list of allowed vertices at all subsequent levels.
+      Collection<RequestPriorityStats> tailStats = priorityStats.tailMap(priority, false).values();
+      if (!tailStats.isEmpty()) {
+        BitSet cumulativeAllowed = new BitSet(vertexDescendants.size());
+        cumulativeAllowed.or(stats.allowedVertices);
+        cumulativeAllowed.andNot(stats.descendants);
+        for (RequestPriorityStats s : tailStats) {
+          s.allowedVertices.clear();
+          s.allowedVertices.or(cumulativeAllowed);
+          cumulativeAllowed.andNot(s.descendants);
+        }
+      }
+    }
+
+    /** True if the highest priority request has waited longer than the preemption deadline. */
+    @GuardedBy("DagAwareYarnTaskScheduler.this")
+    boolean isPreemptionDeadlineExpired() {
+      return highestPriorityWaitTimestamp != 0
+          && now() - highestPriorityWaitTimestamp > preemptionMaxWaitTime;
+    }
+
+    /**
+     * Returns true if at least one request at the highest priority could be
+     * satisfied by the given free resources (vacuously true with no requests).
+     */
+    @GuardedBy("DagAwareYarnTaskScheduler.this")
+    boolean fitsHighestPriorityRequest(Resource freeResources) {
+      if (priorityStats.isEmpty()) {
+        return true;
+      }
+      Priority priority = priorityStats.firstKey();
+      // parameterized type for consistency with getAmountToPreempt (avoids raw Collection)
+      List<? extends Collection<TaskRequest>> requestsList = client.getMatchingRequests(
+          priority, ResourceRequest.ANY, freeResources);
+      return !requestsList.isEmpty();
+    }
+
+    /**
+     * Computes the total resources to preempt to satisfy the configured
+     * percentage of the highest priority outstanding requests.
+     */
+    @GuardedBy("DagAwareYarnTaskScheduler.this")
+    Resource getAmountToPreempt(int preemptionPercentage) {
+      if (priorityStats.isEmpty()) {
+        return Resources.none();
+      }
+      Priority priority = priorityStats.firstKey();
+      List<? extends Collection<TaskRequest>> requestsList = client.getMatchingRequests(
+          priority, ResourceRequest.ANY, Resources.unbounded());
+      int numRequests = 0;
+      for (Collection<TaskRequest> requests : requestsList) {
+        numRequests += requests.size();
+      }
+      numRequests = (int) Math.ceil(numRequests * (preemptionPercentage / 100.f));
+      Resource toPreempt = Resource.newInstance(0, 0);
+      if (numRequests != 0) {
+        outer_loop:
+        for (Collection<TaskRequest> requests : requestsList) {
+          for (TaskRequest request : requests) {
+            Resources.addTo(toPreempt, request.getCapability());
+            if (--numRequests == 0) {
+              break outer_loop;
+            }
+          }
+        }
+      }
+      return toPreempt;
+    }
+
+    // Create a new BitSet that represents all of the vertices that should not be
+    // scheduled due to outstanding requests from higher priority predecessor vertices.
+    @GuardedBy("DagAwareYarnTaskScheduler.this")
+    BitSet createVertexBlockedSet() {
+      BitSet blocked = new BitSet();
+      Entry<Priority, RequestPriorityStats> entry = priorityStats.lastEntry();
+      if (entry != null) {
+        RequestPriorityStats stats = entry.getValue();
+        blocked.or(stats.allowedVertices);
+        blocked.flip(0, blocked.length());
+        blocked.or(stats.descendants);
+      }
+      return blocked;
+    }
+  }
+
+  /**
+   * Tracks statistics on vertices that are requesting tasks at a particular priority
+   */
+  private static class RequestPriorityStats {
+    /** Map from vertex ID to number of task requests for that vertex */
+    final Map<Integer, MutableInt> vertexTaskCount = new HashMap<>();
+    /** BitSet of vertices that have outstanding requests at this priority */
+    final BitSet vertices;
+    /** BitSet of vertices that are descendants of this vertex */
+    final BitSet descendants;
+    /**
+     * BitSet of vertices that are allowed to be scheduled at this priority
+     * (i.e.: no outstanding predecessors requesting at higher priorities)
+     */
+    final BitSet allowedVertices;
+    /** Number of outstanding task requests at this priority. */
+    int requestCount = 0;
+    /** Number of those requests that specified locality preferences. */
+    int localityCount = 0;
+
+    RequestPriorityStats(int numTotalVertices, BitSet allowedVertices) {
+      this.vertices = new BitSet(numTotalVertices);
+      this.descendants = new BitSet(numTotalVertices);
+      this.allowedVertices = allowedVertices;
+    }
+  }
+
+  /**
+   * Tracks idle containers and facilitates faster matching of task requests
+   * against those containers given a desired location.
+   */
+  private static class IdleContainerTracker {
+    /**
+     * Map of location ID (e.g.: a specific host, rack, or ANY) to set of
+     * idle containers matching that location
+     */
+    final Map<String, Set<HeldContainer>> containersByLocation = new HashMap<>();
+    /** Total number of idle containers currently tracked. */
+    int numContainers = 0;
+
+    /** Indexes an idle container under its host, its rack, and the ANY location. */
+    @GuardedBy("DagAwareYarnTaskScheduler.this")
+    void add(HeldContainer hc) {
+      add(hc, hc.getHost());
+      add(hc, hc.getRack());
+      add(hc, ResourceRequest.ANY);
+      ++numContainers;
+    }
+
+    /** Removes a container from all three of its location indexes. */
+    @GuardedBy("DagAwareYarnTaskScheduler.this")
+    void remove(HeldContainer hc) {
+      remove(hc, hc.getHost());
+      remove(hc, hc.getRack());
+      remove(hc, ResourceRequest.ANY);
+      --numContainers;
+    }
+
+    @GuardedBy("DagAwareYarnTaskScheduler.this")
+    int getNumContainers() {
+      return numContainers;
+    }
+
+    private void add(HeldContainer hc, String location) {
+      // computeIfAbsent replaces the manual get/null-check/put pattern
+      containersByLocation.computeIfAbsent(location, k -> new HashSet<>()).add(hc);
+    }
+
+    private void remove(HeldContainer hc, String location) {
+      Set<HeldContainer> containers = containersByLocation.get(location);
+      if (containers != null) {
+        // drop the location bucket entirely once its last container is removed
+        if (containers.remove(hc) && containers.isEmpty()) {
+          containersByLocation.remove(location);
+        }
+      }
+    }
+
+    /** Returns the idle containers indexed under the given location, or null if none. */
+    @GuardedBy("DagAwareYarnTaskScheduler.this")
+    @Nullable
+    Set<HeldContainer> getByLocation(String location) {
+      return containersByLocation.get(location);
+    }
+  }
+
+  /** Strategy for testing and deducting resource availability during scheduling. */
+  private interface ResourceCalculator {
+    /** Returns true if any schedulable resource remains in the given resource. */
+    boolean anyAvailable(Resource rsrc);
+    /** Subtracts toSubtract from total in place. */
+    void deductFrom(Resource total, Resource toSubtract);
+  }
+
+  /**
+   * ResourceCalculator for memory-only allocation
+   */
+  private static class MemResourceCalculator implements ResourceCalculator {
+
+    @Override
+    public boolean anyAvailable(Resource rsrc) {
+      return rsrc.getMemory() > 0;
+    }
+
+    @Override
+    public void deductFrom(Resource total, Resource toSubtract) {
+      total.setMemory(total.getMemory() - toSubtract.getMemory());
+    }
+  }
+
+  /**
+   * ResourceCalculator for memory and vcore allocation
+   */
+  private static class MemCpuResourceCalculator extends MemResourceCalculator {
+
+    @Override
+    public boolean anyAvailable(Resource rsrc) {
+      // available if either memory or vcores remain
+      return super.anyAvailable(rsrc) || rsrc.getVirtualCores() > 0;
+    }
+
+    @Override
+    public void deductFrom(Resource total, Resource toSubtract) {
+      super.deductFrom(total, toSubtract);
+      total.setVirtualCores(total.getVirtualCores() - toSubtract.getVirtualCores());
+    }
+  }
+
+  /**
+   * Scheduled thread pool executor that logs any errors that escape the worker thread.
+   * This can be replaced with HadoopThreadPoolExecutor once Tez requires Hadoop 2.8 or later.
+   */
+  static class ReuseContainerExecutor extends ScheduledThreadPoolExecutor {
+    ReuseContainerExecutor() {
+      // single thread; cancelled tasks are removed from the queue and delayed
+      // tasks do not run after shutdown
+      super(1, new ThreadFactoryBuilder().setNameFormat("ReuseContainerExecutor #%d").build());
+      setRemoveOnCancelPolicy(true);
+      setExecuteExistingDelayedTasksAfterShutdownPolicy(false);
+    }
+
+    /**
+     * Extracts and logs any throwable from the completed task's Future, since
+     * exceptions thrown inside a submitted Callable are otherwise swallowed.
+     */
+    @Override
+    protected void afterExecute(Runnable r, Throwable t) {
+      super.afterExecute(r, t);
+
+      if (t == null && r instanceof Future<?>) {
+        try {
+          ((Future<?>) r).get();
+        } catch (ExecutionException ee) {
+          LOG.warn("Execution exception when running task in {}",  Thread.currentThread().getName());
+          t = ee.getCause();
+        } catch (InterruptedException ie) {
+          LOG.warn("Thread ({}) interrupted: ", Thread.currentThread(), ie);
+          Thread.currentThread().interrupt();
+        } catch (Throwable throwable) {
+          // NOTE(review): a CancellationException from a cancelled task lands here
+          // and gets logged below as a warning; with setRemoveOnCancelPolicy(true)
+          // this should be rare — confirm that is the intended behavior
+          t = throwable;
+        }
+      }
+
+      if (t != null) {
+        LOG.warn("Caught exception in thread {}", Thread.currentThread().getName(), t);
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/tez/blob/a9b8bb5a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/TaskSchedulerContextImpl.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/TaskSchedulerContextImpl.java b/tez-dag/src/main/java/org/apache/tez/dag/app/rm/TaskSchedulerContextImpl.java
index 39000d6..f273c8b 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/TaskSchedulerContextImpl.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/rm/TaskSchedulerContextImpl.java
@@ -183,6 +183,11 @@ public class TaskSchedulerContextImpl implements TaskSchedulerContext {
   }
 
   @Override
+  public int getVertexIndexForTask(Object task) {
+    return taskSchedulerManager.getVertexIndexForTask(task);
+  }
+
+  @Override
   public void reportError(ServicePluginError servicePluginError, String diagnostics,
                           DagInfo dagInfo) {
     Preconditions.checkNotNull(servicePluginError, "ServicePluginError must be specified");


Mime
View raw message