tez-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From ss...@apache.org
Subject tez git commit: TEZ-2149. Optimizations for the timed version of DAGClient.getStatus. (sseth)
Date Tue, 31 Mar 2015 04:33:20 GMT
Repository: tez
Updated Branches:
  refs/heads/master 008f9bc1e -> 60f413d9b


TEZ-2149. Optimizations for the timed version of DAGClient.getStatus. (sseth)


Project: http://git-wip-us.apache.org/repos/asf/tez/repo
Commit: http://git-wip-us.apache.org/repos/asf/tez/commit/60f413d9
Tree: http://git-wip-us.apache.org/repos/asf/tez/tree/60f413d9
Diff: http://git-wip-us.apache.org/repos/asf/tez/diff/60f413d9

Branch: refs/heads/master
Commit: 60f413d9bbe4a415efc48acf5808995b1787fc53
Parents: 008f9bc
Author: Siddharth Seth <sseth@apache.org>
Authored: Mon Mar 30 21:32:52 2015 -0700
Committer: Siddharth Seth <sseth@apache.org>
Committed: Mon Mar 30 21:32:52 2015 -0700

----------------------------------------------------------------------
 .../apache/tez/dag/api/client/DAGClient.java    |   1 -
 .../tez/dag/api/client/DAGClientImpl.java       | 120 ++++++---
 .../dag/api/client/DAGClientTimelineImpl.java   |   2 +-
 .../apache/tez/dag/api/client/DAGStatus.java    |  17 +-
 .../tez/dag/api/client/DagStatusSource.java     |  22 ++
 .../dag/api/client/rpc/DAGClientRPCImpl.java    |   3 +-
 .../tez/dag/api/client/rpc/TestDAGClient.java   | 257 +++++++++++++++++--
 .../tez/dag/api/client/DAGStatusBuilder.java    |   2 +-
 .../apache/tez/dag/app/dag/impl/DAGImpl.java    |  68 +++--
 .../tez/dag/app/dag/impl/TestDAGImpl.java       | 174 +++++++++++++
 10 files changed, 572 insertions(+), 94 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/tez/blob/60f413d9/tez-api/src/main/java/org/apache/tez/dag/api/client/DAGClient.java
----------------------------------------------------------------------
diff --git a/tez-api/src/main/java/org/apache/tez/dag/api/client/DAGClient.java b/tez-api/src/main/java/org/apache/tez/dag/api/client/DAGClient.java
index 13c8ce6..27b316b 100644
--- a/tez-api/src/main/java/org/apache/tez/dag/api/client/DAGClient.java
+++ b/tez-api/src/main/java/org/apache/tez/dag/api/client/DAGClient.java
@@ -71,7 +71,6 @@ public abstract class DAGClient implements Closeable {
    * @throws IOException
    * @throws TezException
    */
-  @Unstable
   public abstract DAGStatus getDAGStatus(@Nullable Set<StatusGetOpts> statusOptions,
       long timeout)
       throws IOException, TezException;

http://git-wip-us.apache.org/repos/asf/tez/blob/60f413d9/tez-api/src/main/java/org/apache/tez/dag/api/client/DAGClientImpl.java
----------------------------------------------------------------------
diff --git a/tez-api/src/main/java/org/apache/tez/dag/api/client/DAGClientImpl.java b/tez-api/src/main/java/org/apache/tez/dag/api/client/DAGClientImpl.java
index dd83ecc..fac1d36 100644
--- a/tez-api/src/main/java/org/apache/tez/dag/api/client/DAGClientImpl.java
+++ b/tez-api/src/main/java/org/apache/tez/dag/api/client/DAGClientImpl.java
@@ -29,6 +29,7 @@ import java.util.Set;
 
 import com.google.common.annotations.VisibleForTesting;
 
+import com.google.common.base.Preconditions;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.apache.hadoop.classification.InterfaceAudience.Private;
@@ -116,54 +117,72 @@ public class DAGClientImpl extends DAGClient {
   @Override
   public DAGStatus getDAGStatus(@Nullable Set<StatusGetOpts> statusOptions,
       final long timeout) throws TezException, IOException {
-    long currentStatusPollInterval = statusPollInterval;
-    if(timeout >= 0 && currentStatusPollInterval > timeout) {
-      currentStatusPollInterval = timeout;
+
+    Preconditions.checkArgument(timeout >= -1, "Timeout must be >= -1");
+    // Short circuit a timeout of 0.
+    if (timeout == 0) {
+      return getDAGStatusInternal(statusOptions, timeout);
     }
-    DAGStatus dagStatus = null;
+
+    long startTime = System.currentTimeMillis();
+    boolean refreshStatus;
+    DAGStatus dagStatus;
     if(cachedDagStatus != null) {
       dagStatus = cachedDagStatus;
+      refreshStatus = true;
     } else {
+      // For the first lookup only. After this cachedDagStatus should be populated.
       dagStatus = getDAGStatus(statusOptions);
+      refreshStatus = false;
     }
-    //Handling when client dag status init or submitted
+
+    // Handling when client dag status init or submitted. This really implies that the RM was
+    // contacted to get status. INITING is never used. DAG_INITING implies a DagState of RUNNING.
     if (dagStatus.getState() == DAGStatus.State.INITING
         || dagStatus.getState() == DAGStatus.State.SUBMITTED) {
-      boolean initOrSubmittedState = true;
-      long timeoutTime = System.currentTimeMillis() + timeout;
+      long timeoutAbsolute = startTime + timeout;
       while (timeout < 0
-          || (timeout > 0 && timeoutTime > System.currentTimeMillis())) {
-        if (initOrSubmittedState) {
-          dagStatus = getDAGStatus(statusOptions);
+          || (timeout > 0 && timeoutAbsolute > System.currentTimeMillis())) {
+        if (refreshStatus) {
+          // Try fetching the state with a timeout, in case the AM is already up.
+          dagStatus = getDAGStatusInternal(statusOptions, timeout);
         }
+        refreshStatus = true; // For the next iteration of the loop.
+
         if (dagStatus.getState() == DAGStatus.State.RUNNING) {
-          initOrSubmittedState = false;
-          // When RUNNING State, Check for AM status is also RUNNING
-          DAGStatus dagStatusFromAM = getDAGStatusViaAM(statusOptions, 0);
-          if (dagStatusFromAM != null) {
-            if (dagStatusFromAM.getState() == DAGStatus.State.RUNNING) {
-              long remainingTimeout = 0;
-              if (timeout <= 0) {
-                remainingTimeout = timeout;
-              } else {
-                if (timeoutTime > System.currentTimeMillis()) {
-                  remainingTimeout = timeoutTime - System.currentTimeMillis();
-                } else {
-                  return dagStatusFromAM;
-                }
-              }
-              dagStatus = getDAGStatusInternal(statusOptions, remainingTimeout);
-            } else {
-              dagStatus = dagStatusFromAM;
-            }
-            break;
+          // Refreshed status indicates that the DAG is running.
+          // This status could have come from the AM or the RM - client sleep if RM, otherwise send request to the AM.
+          if (dagStatus.getSource() == DagStatusSource.AM) {
+            // RUNNING + AM should only happen if timeout is > -1.
+            // Otherwise the AM ignored the -1 value, or the AM source in the DAGStatus is invalid.
+            Preconditions.checkState(timeout > -1, "Should not reach here with a timeout of -1. File a bug");
+            return dagStatus;
+          } else {
+            // From the RM. Fall through to the Sleep.
           }
-        }
-        if(dagStatus.getState() == DAGStatus.State.SUCCEEDED
+        } else if(dagStatus.getState() == DAGStatus.State.SUCCEEDED
             || dagStatus.getState() == DAGStatus.State.FAILED
             || dagStatus.getState() == DAGStatus.State.KILLED
             || dagStatus.getState() == DAGStatus.State.ERROR) {
-          break;
+          // Again, check if this was from the RM. If it was, try getting it from a more informative source.
+          if (dagStatus.getSource() == DagStatusSource.RM) {
+            return getDAGStatusInternal(statusOptions, 0);
+          } else {
+            return dagStatus;
+          }
+        }
+        // Sleep before checking again.
+        long currentStatusPollInterval;
+        if (timeout < 0) {
+          currentStatusPollInterval = statusPollInterval;
+        } else {
+          long remainingTimeout = timeoutAbsolute - System.currentTimeMillis();
+          if (remainingTimeout < 0) {
+            // Timeout expired. Return the latest known dag status.
+            return dagStatus;
+          } else {
+            currentStatusPollInterval = remainingTimeout < statusPollInterval ? remainingTimeout : statusPollInterval;
+          }
         }
         try {
           Thread.sleep(currentStatusPollInterval);
@@ -171,8 +190,13 @@ public class DAGClientImpl extends DAGClient {
           throw new TezException(e);
         }
       }// End of while
-      return dagStatus;
-    } else {
+      // Timeout may have expired before a single refresh
+      if (refreshStatus) {
+        return getDAGStatus(statusOptions);
+      } else {
+        return dagStatus;
+      }
+    } else { // Already running, or complete. Fallback to regular dagStatus with a timeout.
       return getDAGStatusInternal(statusOptions, timeout);
     }
   }
@@ -184,6 +208,8 @@ public class DAGClientImpl extends DAGClient {
       // fetch from AM. on Error and while DAG is still not completed (could not reach AM, AM got
       // killed). return cached status. This prevents the progress being reset (for ex fetching from
       // RM does not give status).
+
+      // dagCompleted may be reset within getDagStatusViaAM
       final DAGStatus dagStatus = getDAGStatusViaAM(statusOptions, timeout);
 
       if (!dagCompleted) {
@@ -306,6 +332,13 @@ public class DAGClientImpl extends DAGClient {
     }
   }
 
+  /**
+   * Get the DAG status via the AM
+   * @param statusOptions
+   * @param timeout
+   * @return null if the AM cannot be contacted, otherwise the DAGstatus
+   * @throws IOException
+   */
   private DAGStatus getDAGStatusViaAM(@Nullable Set<StatusGetOpts> statusOptions,
       long timeout) throws IOException {
     DAGStatus dagStatus = null;
@@ -342,7 +375,14 @@ public class DAGClientImpl extends DAGClient {
     return vertexStatus;
   }
 
-  DAGStatus getDAGStatusViaRM() throws TezException, IOException {
+  /**
+   * Get the DAG status via the YARN ResourceManager
+   * @return the dag status, inferred from the RM App state. Does not return null.
+   * @throws TezException
+   * @throws IOException
+   */
+  @VisibleForTesting
+  protected DAGStatus getDAGStatusViaRM() throws TezException, IOException {
     if(LOG.isDebugEnabled()) {
       LOG.debug("GetDAGStatus via AM for app: " + appId + " dag:" + dagId);
     }
@@ -358,7 +398,7 @@ public class DAGClientImpl extends DAGClient {
     }
 
     DAGProtos.DAGStatusProto.Builder builder = DAGProtos.DAGStatusProto.newBuilder();
-    DAGStatus dagStatus = new DAGStatus(builder);
+    DAGStatus dagStatus = new DAGStatus(builder, DagStatusSource.RM);
     DAGProtos.DAGStatusStateProto dagState;
     switch (appReport.getYarnApplicationState()) {
       case NEW:
@@ -416,7 +456,7 @@ public class DAGClientImpl extends DAGClient {
     double dagProgress = -1.0; // Print the first one
     // monitoring
     while (true) {
-      dagStatus = getDAGStatus(statusGetOpts);
+      dagStatus = getDAGStatus(statusGetOpts, SLEEP_FOR_COMPLETION);
       if (!initPrinted
           && (dagStatus.getState() == DAGStatus.State.INITING || dagStatus.getState() == DAGStatus.State.SUBMITTED)) {
         initPrinted = true; // Print once
@@ -429,7 +469,6 @@ public class DAGClientImpl extends DAGClient {
           || dagStatus.getState() == DAGStatus.State.ERROR) {
         break;
       }
-      Thread.sleep(SLEEP_FOR_COMPLETION);
     }// End of while(true)
 
     Set<String> vertexNames = Collections.emptySet();
@@ -442,8 +481,7 @@ public class DAGClientImpl extends DAGClient {
         vertexNames = getDAGStatus(statusGetOpts).getVertexProgress().keySet();
       }
       dagProgress = monitorProgress(vertexNames, dagProgress, null, dagStatus);
-      Thread.sleep(SLEEP_FOR_COMPLETION);
-      dagStatus = getDAGStatus(statusGetOpts);
+      dagStatus = getDAGStatus(statusGetOpts, SLEEP_FOR_COMPLETION);
     }// end of while
     // Always print the last status irrespective of progress change
     monitorProgress(vertexNames, -1.0, statusGetOpts, dagStatus);

http://git-wip-us.apache.org/repos/asf/tez/blob/60f413d9/tez-api/src/main/java/org/apache/tez/dag/api/client/DAGClientTimelineImpl.java
----------------------------------------------------------------------
diff --git a/tez-api/src/main/java/org/apache/tez/dag/api/client/DAGClientTimelineImpl.java b/tez-api/src/main/java/org/apache/tez/dag/api/client/DAGClientTimelineImpl.java
index fe4b033..4a5a4e2 100644
--- a/tez-api/src/main/java/org/apache/tez/dag/api/client/DAGClientTimelineImpl.java
+++ b/tez-api/src/main/java/org/apache/tez/dag/api/client/DAGClientTimelineImpl.java
@@ -144,7 +144,7 @@ public class DAGClientTimelineImpl extends DAGClient {
         throw new TezException("Failed to get DagStatus from ATS");
       }
 
-      return new DAGStatus(statusBuilder);
+      return new DAGStatus(statusBuilder, DagStatusSource.TIMELINE);
     } catch (JSONException je) {
       throw new TezException("Failed to parse DagStatus json from YARN Timeline", je);
     }

http://git-wip-us.apache.org/repos/asf/tez/blob/60f413d9/tez-api/src/main/java/org/apache/tez/dag/api/client/DAGStatus.java
----------------------------------------------------------------------
diff --git a/tez-api/src/main/java/org/apache/tez/dag/api/client/DAGStatus.java b/tez-api/src/main/java/org/apache/tez/dag/api/client/DAGStatus.java
index d079da2..7e48334 100644
--- a/tez-api/src/main/java/org/apache/tez/dag/api/client/DAGStatus.java
+++ b/tez-api/src/main/java/org/apache/tez/dag/api/client/DAGStatus.java
@@ -24,6 +24,7 @@ import java.util.Map;
 import java.util.concurrent.atomic.AtomicBoolean;
 
 import org.apache.commons.lang.StringUtils;
+import org.apache.hadoop.classification.InterfaceAudience;
 import org.apache.hadoop.classification.InterfaceAudience.Public;
 import org.apache.tez.common.counters.TezCounters;
 import org.apache.tez.dag.api.DAG;
@@ -42,8 +43,8 @@ public class DAGStatus {
     .getProperty("line.separator");
 
   public enum State {
-    SUBMITTED,
-    INITING,
+    SUBMITTED, // Returned from the RM only
+    INITING, // This is currently never returned. DAG_INITING is treated as RUNNING.
     RUNNING,
     SUCCEEDED,
     KILLED,
@@ -51,6 +52,7 @@ public class DAGStatus {
     ERROR,
   }
 
+  final DagStatusSource source;
   DAGStatusProtoOrBuilder proxy = null;
   Progress progress = null;
   // use LinkedHashMap to ensure the vertex order (TEZ-1065)
@@ -58,8 +60,10 @@ public class DAGStatus {
   TezCounters dagCounters = null;
   AtomicBoolean countersInitialized = new AtomicBoolean(false);
 
-  public DAGStatus(DAGStatusProtoOrBuilder proxy) {
+  @InterfaceAudience.Private
+  public DAGStatus(DAGStatusProtoOrBuilder proxy, DagStatusSource source) {
     this.proxy = proxy;
+    this.source = source;
   }
 
   public State getState() {
@@ -147,8 +151,14 @@ public class DAGStatus {
     return dagCounters;
   }
 
+  @InterfaceAudience.Private
+  DagStatusSource getSource() {
+    return this.source;
+  }
+
   @Override
   public boolean equals(Object obj) {
+    // Source explicitly exclude from equals
     if (obj instanceof DAGStatus){
       DAGStatus other = (DAGStatus)obj;
       return getState() == other.getState()
@@ -164,6 +174,7 @@ public class DAGStatus {
 
   @Override
   public int hashCode() {
+    // Source explicitly exclude from hashCode
     final int prime = 44017;
     int result = 1;
     result = prime +

http://git-wip-us.apache.org/repos/asf/tez/blob/60f413d9/tez-api/src/main/java/org/apache/tez/dag/api/client/DagStatusSource.java
----------------------------------------------------------------------
diff --git a/tez-api/src/main/java/org/apache/tez/dag/api/client/DagStatusSource.java b/tez-api/src/main/java/org/apache/tez/dag/api/client/DagStatusSource.java
new file mode 100644
index 0000000..58ed853
--- /dev/null
+++ b/tez-api/src/main/java/org/apache/tez/dag/api/client/DagStatusSource.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.dag.api.client;
+
+import org.apache.hadoop.classification.InterfaceAudience;
+
+@InterfaceAudience.Private
+public enum DagStatusSource {
+  AM, RM, TIMELINE
+}

http://git-wip-us.apache.org/repos/asf/tez/blob/60f413d9/tez-api/src/main/java/org/apache/tez/dag/api/client/rpc/DAGClientRPCImpl.java
----------------------------------------------------------------------
diff --git a/tez-api/src/main/java/org/apache/tez/dag/api/client/rpc/DAGClientRPCImpl.java b/tez-api/src/main/java/org/apache/tez/dag/api/client/rpc/DAGClientRPCImpl.java
index 27c54fc..223c0ab 100644
--- a/tez-api/src/main/java/org/apache/tez/dag/api/client/rpc/DAGClientRPCImpl.java
+++ b/tez-api/src/main/java/org/apache/tez/dag/api/client/rpc/DAGClientRPCImpl.java
@@ -40,6 +40,7 @@ import org.apache.tez.dag.api.TezConfiguration;
 import org.apache.tez.dag.api.TezException;
 import org.apache.tez.dag.api.client.DAGClient;
 import org.apache.tez.dag.api.client.DAGStatus;
+import org.apache.tez.dag.api.client.DagStatusSource;
 import org.apache.tez.dag.api.client.StatusGetOpts;
 import org.apache.tez.dag.api.client.VertexStatus;
 import org.apache.tez.dag.api.client.rpc.DAGClientAMProtocolRPC.GetDAGStatusRequestProto;
@@ -173,7 +174,7 @@ public class DAGClientRPCImpl extends DAGClient {
     try {
       return new DAGStatus(
         proxy.getDAGStatus(null,
-          requestProtoBuilder.build()).getDagStatus());
+          requestProtoBuilder.build()).getDagStatus(), DagStatusSource.AM);
     } catch (ServiceException e) {
       final Throwable cause = e.getCause();
       if (cause instanceof RemoteException) {

http://git-wip-us.apache.org/repos/asf/tez/blob/60f413d9/tez-api/src/test/java/org/apache/tez/dag/api/client/rpc/TestDAGClient.java
----------------------------------------------------------------------
diff --git a/tez-api/src/test/java/org/apache/tez/dag/api/client/rpc/TestDAGClient.java b/tez-api/src/test/java/org/apache/tez/dag/api/client/rpc/TestDAGClient.java
index c6894ef..5143a3c 100644
--- a/tez-api/src/test/java/org/apache/tez/dag/api/client/rpc/TestDAGClient.java
+++ b/tez-api/src/test/java/org/apache/tez/dag/api/client/rpc/TestDAGClient.java
@@ -21,21 +21,27 @@ package org.apache.tez.dag.api.client.rpc;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.Matchers.*;
+import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
+import javax.annotation.Nullable;
 import java.io.IOException;
+import java.util.EnumSet;
+import java.util.Set;
 
 import org.apache.hadoop.yarn.api.records.ApplicationId;
 import org.apache.hadoop.yarn.api.records.ApplicationReport;
 import org.apache.hadoop.yarn.exceptions.YarnException;
+import org.apache.tez.client.FrameworkClient;
 import org.apache.tez.dag.api.TezConfiguration;
 import org.apache.tez.dag.api.TezException;
 import org.apache.tez.dag.api.client.DAGClient;
 import org.apache.tez.dag.api.client.DAGClientImpl;
 import org.apache.tez.dag.api.client.DAGStatus;
+import org.apache.tez.dag.api.client.DagStatusSource;
 import org.apache.tez.dag.api.client.StatusGetOpts;
 import org.apache.tez.dag.api.client.VertexStatus;
 import org.apache.tez.dag.api.client.rpc.DAGClientAMProtocolRPC.GetDAGStatusRequestProto;
@@ -43,6 +49,7 @@ import org.apache.tez.dag.api.client.rpc.DAGClientAMProtocolRPC.GetDAGStatusResp
 import org.apache.tez.dag.api.client.rpc.DAGClientAMProtocolRPC.GetVertexStatusRequestProto;
 import org.apache.tez.dag.api.client.rpc.DAGClientAMProtocolRPC.GetVertexStatusResponseProto;
 import org.apache.tez.dag.api.client.rpc.DAGClientAMProtocolRPC.TryKillDAGRequestProto;
+import org.apache.tez.dag.api.records.DAGProtos;
 import org.apache.tez.dag.api.records.DAGProtos.DAGStatusProto;
 import org.apache.tez.dag.api.records.DAGProtos.DAGStatusStateProto;
 import org.apache.tez.dag.api.records.DAGProtos.ProgressProto;
@@ -55,11 +62,14 @@ import org.apache.tez.dag.api.records.DAGProtos.VertexStatusProto;
 import org.apache.tez.dag.api.records.DAGProtos.VertexStatusStateProto;
 import org.junit.Before;
 import org.junit.Test;
+import org.mockito.ArgumentCaptor;
 import org.mockito.ArgumentMatcher;
 import org.mockito.internal.util.collections.Sets;
 
 import com.google.protobuf.RpcController;
 import com.google.protobuf.ServiceException;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
 
 public class TestDAGClient {
 
@@ -208,13 +218,13 @@ public class TestDAGClient {
     DAGStatus resultDagStatus = dagClient.getDAGStatus(null);
     verify(mockProxy, times(1)).getDAGStatus(null, GetDAGStatusRequestProto.newBuilder()
         .setDagId(dagIdStr).setTimeout(0).build());
-    assertEquals(new DAGStatus(dagStatusProtoWithoutCounters), resultDagStatus);
+    assertEquals(new DAGStatus(dagStatusProtoWithoutCounters, DagStatusSource.AM), resultDagStatus);
     System.out.println("DAGStatusWithoutCounter:" + resultDagStatus);
     
     resultDagStatus = dagClient.getDAGStatus(Sets.newSet(StatusGetOpts.GET_COUNTERS));
     verify(mockProxy, times(1)).getDAGStatus(null, GetDAGStatusRequestProto.newBuilder()
         .setDagId(dagIdStr).setTimeout(0).addStatusOptions(StatusGetOptsProto.GET_COUNTERS).build());
-    assertEquals(new DAGStatus(dagStatusProtoWithCounters), resultDagStatus);
+    assertEquals(new DAGStatus(dagStatusProtoWithCounters, DagStatusSource.AM), resultDagStatus);
     System.out.println("DAGStatusWithCounter:" + resultDagStatus);
   }
   
@@ -245,14 +255,20 @@ public class TestDAGClient {
   public void testWaitForCompletion() throws Exception{
     // first time return DAG_RUNNING, second time return DAG_SUCCEEDED
     when(mockProxy.getDAGStatus(isNull(RpcController.class), any(GetDAGStatusRequestProto.class)))
-      .thenReturn(GetDAGStatusResponseProto.newBuilder().setDagStatus(dagStatusProtoWithoutCounters).build())
+      .thenReturn(GetDAGStatusResponseProto.newBuilder().setDagStatus(dagStatusProtoWithoutCounters)
+          .build())
       .thenReturn(GetDAGStatusResponseProto.newBuilder().setDagStatus
-                    (DAGStatusProto.newBuilder(dagStatusProtoWithoutCounters).setState(DAGStatusStateProto.DAG_SUCCEEDED).build())
-                 .build());
-      
+          (DAGStatusProto.newBuilder(dagStatusProtoWithoutCounters)
+              .setState(DAGStatusStateProto.DAG_SUCCEEDED).build())
+          .build());
+
     dagClient.waitForCompletion();
-    verify(mockProxy, times(2)).getDAGStatus(null, GetDAGStatusRequestProto.newBuilder()
-        .setDagId(dagIdStr).setTimeout(0).build());
+    ArgumentCaptor<RpcController> rpcControllerArgumentCaptor =
+        ArgumentCaptor.forClass(RpcController.class);
+    ArgumentCaptor<GetDAGStatusRequestProto> argumentCaptor =
+        ArgumentCaptor.forClass(GetDAGStatusRequestProto.class);
+    verify(mockProxy, times(2))
+        .getDAGStatus(rpcControllerArgumentCaptor.capture(), argumentCaptor.capture());
   }
 
   @Test(timeout = 5000)
@@ -260,29 +276,220 @@ public class TestDAGClient {
 
     // first time and second time return DAG_RUNNING, third time return DAG_SUCCEEDED
     when(mockProxy.getDAGStatus(isNull(RpcController.class), any(GetDAGStatusRequestProto.class)))
-      .thenReturn(GetDAGStatusResponseProto.newBuilder().setDagStatus(dagStatusProtoWithoutCounters)
-          .build())
-      .thenReturn(GetDAGStatusResponseProto.newBuilder().setDagStatus(dagStatusProtoWithoutCounters).build())
-      .thenReturn(GetDAGStatusResponseProto.newBuilder().setDagStatus
-                (DAGStatusProto.newBuilder(dagStatusProtoWithoutCounters).setState(DAGStatusStateProto.DAG_SUCCEEDED).build())
-             .build());
+        .thenReturn(GetDAGStatusResponseProto.newBuilder().setDagStatus(
+            DAGStatusProto.newBuilder(dagStatusProtoWithCounters)
+                .setState(DAGStatusStateProto.DAG_RUNNING).build()).build())
+        .thenReturn(GetDAGStatusResponseProto.newBuilder().setDagStatus(
+            DAGStatusProto.newBuilder(dagStatusProtoWithCounters)
+                .setState(DAGStatusStateProto.DAG_RUNNING).build()).build())
+        .thenReturn(GetDAGStatusResponseProto.newBuilder().setDagStatus(
+            DAGStatusProto.newBuilder(dagStatusProtoWithCounters)
+                .setState(DAGStatusStateProto.DAG_RUNNING).build()).build())
+        .thenReturn(GetDAGStatusResponseProto.newBuilder().setDagStatus
+            (DAGStatusProto.newBuilder(dagStatusProtoWithoutCounters)
+                .setState(DAGStatusStateProto.DAG_SUCCEEDED).build())
+            .build());
     
     //  first time for getVertexSet
     //  second & third time for check completion
+    ArgumentCaptor<RpcController> rpcControllerArgumentCaptor =
+        ArgumentCaptor.forClass(RpcController.class);
+    ArgumentCaptor<GetDAGStatusRequestProto> argumentCaptor =
+        ArgumentCaptor.forClass(GetDAGStatusRequestProto.class);
     dagClient.waitForCompletionWithStatusUpdates(null);
-    verify(mockProxy, times(3)).getDAGStatus(null, GetDAGStatusRequestProto.newBuilder()
-        .setDagId(dagIdStr).setTimeout(0).build());
+    // 2 from initial request - when status isn't cached. 1 for vertex names. 1 for final wait.
+    verify(mockProxy, times(4))
+        .getDAGStatus(rpcControllerArgumentCaptor.capture(), argumentCaptor.capture());
 
-    
     when(mockProxy.getDAGStatus(isNull(RpcController.class), any(GetDAGStatusRequestProto.class)))
-      .thenReturn(GetDAGStatusResponseProto.newBuilder().setDagStatus(dagStatusProtoWithCounters).build())
-      .thenReturn(GetDAGStatusResponseProto.newBuilder().setDagStatus(dagStatusProtoWithCounters).build())
-      .thenReturn(GetDAGStatusResponseProto.newBuilder().setDagStatus
-                (DAGStatusProto.newBuilder(dagStatusProtoWithCounters).setState(DAGStatusStateProto.DAG_SUCCEEDED).build())
-             .build());
+        .thenReturn(GetDAGStatusResponseProto.newBuilder().setDagStatus(
+            DAGStatusProto.newBuilder(dagStatusProtoWithCounters)
+                .setState(DAGStatusStateProto.DAG_RUNNING).build()).build())
+        .thenReturn(GetDAGStatusResponseProto.newBuilder().setDagStatus(
+            DAGStatusProto.newBuilder(dagStatusProtoWithCounters)
+                .setState(DAGStatusStateProto.DAG_RUNNING).build()).build())
+        .thenReturn(GetDAGStatusResponseProto.newBuilder().setDagStatus(
+            DAGStatusProto.newBuilder(dagStatusProtoWithCounters)
+                .setState(DAGStatusStateProto.DAG_RUNNING).build()).build())
+        .thenReturn(GetDAGStatusResponseProto.newBuilder().setDagStatus
+            (DAGStatusProto.newBuilder(dagStatusProtoWithCounters).setState(
+                DAGStatusStateProto.DAG_SUCCEEDED).build())
+            .build());
+
+    rpcControllerArgumentCaptor =
+        ArgumentCaptor.forClass(RpcController.class);
+    argumentCaptor =
+        ArgumentCaptor.forClass(GetDAGStatusRequestProto.class);
     dagClient.waitForCompletionWithStatusUpdates(Sets.newSet(StatusGetOpts.GET_COUNTERS));
-    verify(mockProxy, times(3)).getDAGStatus(null, GetDAGStatusRequestProto.newBuilder()
-      .setDagId(dagIdStr).setTimeout(0).addStatusOptions(StatusGetOptsProto.GET_COUNTERS).build());
+    // 4 from past invocation in the test, 2 from initial request - when status isn't cached. 1 for vertex names. 1 for final wait.
+    verify(mockProxy, times(8))
+        .getDAGStatus(rpcControllerArgumentCaptor.capture(), argumentCaptor.capture());
+  }
+
+  @Test(timeout = 50000)
+  public void testGetDagStatusWithTimeout() throws Exception {
+    long startTime;
+    long endTime;
+    long diff;
+
+    TezConfiguration tezConf = new TezConfiguration();
+    tezConf.setLong(TezConfiguration.TEZ_DAG_STATUS_POLLINTERVAL_MS, 800l);
+
+    DAGClientImplForTest dagClient = new DAGClientImplForTest(mockAppId, dagIdStr, tezConf, null);
+    DAGClientRPCImplForTest dagClientRpc =
+        new DAGClientRPCImplForTest(mockAppId, dagIdStr, tezConf, null);
+    dagClient.setRealClient(dagClientRpc);
+
+    DAGStatus dagStatus;
+
+
+    // Fetch from RM. AM not up yet.
+    dagClientRpc.setAMProxy(null);
+    DAGStatus rmDagStatus =
+        new DAGStatus(constructDagStatusProto(DAGStatusStateProto.DAG_SUBMITTED),
+            DagStatusSource.RM);
+    dagClient.setRmDagStatus(rmDagStatus);
+
+    startTime = System.currentTimeMillis();
+    dagStatus = dagClient.getDAGStatus(EnumSet.noneOf(StatusGetOpts.class), 2000l);
+    endTime = System.currentTimeMillis();
+    diff = endTime - startTime;
+    assertTrue(diff > 1500l && diff < 2500l);
+    // One at start. Second and Third within the sleep. Fourth at final refresh.
+    assertEquals(0, dagClientRpc.numGetStatusViaAmInvocations); // No AM available, so no invocations to AM
+    assertEquals(4, dagClient.numGetStatusViaRmInvocations);
+    assertEquals(DAGStatus.State.SUBMITTED, dagStatus.getState());
+
+    // Fetch from AM. RUNNING
+    dagClient.resetCounters();
+    dagClientRpc.resetCountesr();
+    rmDagStatus =
+        new DAGStatus(constructDagStatusProto(DAGStatusStateProto.DAG_RUNNING), DagStatusSource.RM);
+    dagClient.setRmDagStatus(rmDagStatus);
+    dagClientRpc.setAMProxy(createMockProxy(DAGStatusStateProto.DAG_RUNNING, -1));
+
+    startTime = System.currentTimeMillis();
+    dagStatus = dagClient.getDAGStatus(EnumSet.noneOf(StatusGetOpts.class), 2000l);
+    endTime = System.currentTimeMillis();
+    diff = endTime - startTime;
+    assertTrue(diff > 1500l && diff < 2500l);
+    // Directly from AM
+    assertEquals(0, dagClient.numGetStatusViaRmInvocations);
+    // Directly from AM - one refresh. One with timeout.
+    assertEquals(2, dagClientRpc.numGetStatusViaAmInvocations);
+    assertEquals(DAGStatus.State.RUNNING, dagStatus.getState());
+
+
+    // Fetch from AM. Success.
+    dagClient.resetCounters();
+    dagClientRpc.resetCountesr();
+    rmDagStatus =
+        new DAGStatus(constructDagStatusProto(DAGStatusStateProto.DAG_RUNNING), DagStatusSource.RM);
+    dagClient.setRmDagStatus(rmDagStatus);
+    dagClientRpc.setAMProxy(createMockProxy(DAGStatusStateProto.DAG_SUCCEEDED, 1000l));
+
+    startTime = System.currentTimeMillis();
+    dagStatus = dagClient.getDAGStatus(EnumSet.noneOf(StatusGetOpts.class), 2000l);
+    endTime = System.currentTimeMillis();
+    diff = endTime - startTime;
+    assertTrue(diff > 500l && diff < 1500l);
+    // Directly from AM
+    assertEquals(0, dagClient.numGetStatusViaRmInvocations);
+    // Directly from AM - previous request cached, so single invocation only.
+    assertEquals(1, dagClientRpc.numGetStatusViaAmInvocations);
+    assertEquals(DAGStatus.State.SUCCEEDED, dagStatus.getState());
+
+  }
+
+  private static class DAGClientRPCImplForTest extends DAGClientRPCImpl {
+
+    int numGetStatusViaAmInvocations = 0;
+
+    public DAGClientRPCImplForTest(ApplicationId appId, String dagId,
+                                   TezConfiguration conf,
+                                   @Nullable FrameworkClient frameworkClient) {
+      super(appId, dagId, conf, frameworkClient);
+    }
+
+    void setAMProxy(DAGClientAMProtocolBlockingPB proxy) {
+      this.proxy = proxy;
+    }
+
+    void resetCountesr() {
+      numGetStatusViaAmInvocations = 0;
+    }
+
+    @Override
+    boolean createAMProxyIfNeeded() throws IOException, TezException {
+      if (proxy == null) {
+        return false;
+      } else {
+        return true;
+      }
+    }
+
+    @Override
+    DAGStatus getDAGStatusViaAM(Set<StatusGetOpts> statusOptions, long timeout)
+        throws IOException, TezException {
+      numGetStatusViaAmInvocations++;
+      return super.getDAGStatusViaAM(statusOptions, timeout);
+    }
+  }
+
+  private static class DAGClientImplForTest extends DAGClientImpl {
+
+    private DAGStatus rmDagStatus;
+    int numGetStatusViaRmInvocations = 0;
+
+    public DAGClientImplForTest(ApplicationId appId, String dagId,
+                                TezConfiguration conf,
+                                @Nullable FrameworkClient frameworkClient) {
+      super(appId, dagId, conf, frameworkClient);
+    }
+
+    private void setRealClient(DAGClientRPCImplForTest dagClientRpcImplForTest) {
+      this.realClient = dagClientRpcImplForTest;
+    }
+
+    void setRmDagStatus(DAGStatus rmDagStatus) {
+      this.rmDagStatus = rmDagStatus;
+    }
+
+    void resetCounters() {
+      numGetStatusViaRmInvocations = 0;
+    }
+
+    @Override
+    protected DAGStatus getDAGStatusViaRM() throws TezException, IOException {
+      numGetStatusViaRmInvocations++;
+      return rmDagStatus;
+    }
+  }
+
+  private DAGProtos.DAGStatusProto.Builder constructDagStatusProto(DAGStatusStateProto stateProto) {
+    DAGProtos.DAGStatusProto.Builder builder = DAGProtos.DAGStatusProto.newBuilder();
+    builder.setState(stateProto);
+    return builder;
+  }
+
+  private DAGClientAMProtocolBlockingPB createMockProxy(final DAGStatusStateProto stateProto,
+                                                        final long timeout) throws
+      ServiceException {
+    DAGClientAMProtocolBlockingPB mock = mock(DAGClientAMProtocolBlockingPB.class);
+
+    doAnswer(new Answer() {
+      @Override
+      public Object answer(InvocationOnMock invocation) throws Throwable {
+        GetDAGStatusRequestProto request = (GetDAGStatusRequestProto) invocation.getArguments()[1];
+        long sleepTime = request.getTimeout();
+        if (timeout != -1) {
+          sleepTime = timeout;
+        }
+        Thread.sleep(sleepTime);
+        return GetDAGStatusResponseProto.newBuilder().setDagStatus(constructDagStatusProto(
+            stateProto)).build();
+      }
+    }).when(mock).getDAGStatus(isNull(RpcController.class), any(GetDAGStatusRequestProto.class));
+    return mock;
   }
-  
 }

http://git-wip-us.apache.org/repos/asf/tez/blob/60f413d9/tez-dag/src/main/java/org/apache/tez/dag/api/client/DAGStatusBuilder.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/api/client/DAGStatusBuilder.java b/tez-dag/src/main/java/org/apache/tez/dag/api/client/DAGStatusBuilder.java
index df3f4c7..b0a2c63 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/api/client/DAGStatusBuilder.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/api/client/DAGStatusBuilder.java
@@ -32,7 +32,7 @@ import org.apache.tez.dag.app.dag.DAGState;
 public class DAGStatusBuilder extends DAGStatus {
 
   public DAGStatusBuilder() {
-    super(DAGStatusProto.newBuilder());
+    super(DAGStatusProto.newBuilder(), null);
   }
 
   public void setState(DAGState state) {

http://git-wip-us.apache.org/repos/asf/tez/blob/60f413d9/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/DAGImpl.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/DAGImpl.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/DAGImpl.java
index 37ed365..e685f1b 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/DAGImpl.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/DAGImpl.java
@@ -33,6 +33,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
 import java.util.Set;
+import java.util.concurrent.locks.Condition;
 import java.util.concurrent.locks.Lock;
 import java.util.concurrent.locks.ReadWriteLock;
 import java.util.concurrent.locks.ReentrantReadWriteLock;
@@ -149,6 +150,7 @@ public class DAGImpl implements org.apache.tez.dag.app.dag.DAG,
   private final TaskAttemptListener taskAttemptListener;
   private final TaskHeartbeatHandler taskHeartbeatHandler;
   private final Object tasksSyncHandle = new Object();
+  private final Condition dagCompleteCondition;
 
   private volatile boolean committedOrAborted = false;
   private volatile boolean allOutputsCommitted = false;
@@ -390,7 +392,6 @@ public class DAGImpl implements org.apache.tez.dag.app.dag.DAG,
   @VisibleForTesting
   boolean recoveryCommitInProgress = false;
   Map<String, Boolean> recoveredGroupCommits = new HashMap<String, Boolean>();
-  long statusPollInterval;
 
   static class VertexGroupInfo {
     String groupName;
@@ -445,7 +446,8 @@ public class DAGImpl implements org.apache.tez.dag.app.dag.DAG,
     ReadWriteLock readWriteLock = new ReentrantReadWriteLock();
     this.readLock = readWriteLock.readLock();
     this.writeLock = readWriteLock.writeLock();
-    
+    this.dagCompleteCondition = writeLock.newCondition();
+
     this.localResources = DagTypeConverters.createLocalResourceMapFromDAGPlan(jobPlan
         .getLocalResourceList());
 
@@ -472,13 +474,6 @@ public class DAGImpl implements org.apache.tez.dag.app.dag.DAG,
     //  instance variable.
     stateMachine = stateMachineFactory.make(this);
     this.entityUpdateTracker = new StateChangeNotifier(this);
-    statusPollInterval = dagConf.getLong(
-        TezConfiguration.TEZ_DAG_STATUS_POLLINTERVAL_MS,
-        TezConfiguration.TEZ_DAG_STATUS_POLLINTERVAL_MS_DEFAULT);
-    if(statusPollInterval < 0) {
-      LOG.error("DAG Status poll interval cannot be negative and setting to default value.");
-      statusPollInterval = TezConfiguration.TEZ_DAG_STATUS_POLLINTERVAL_MS_DEFAULT;
-    }
   }
 
   protected StateMachine<DAGState, DAGEventType, DAGEvent> getStateMachine() {
@@ -748,20 +743,34 @@ public class DAGImpl implements org.apache.tez.dag.app.dag.DAG,
   }
 
   public DAGStatusBuilder getDAGStatus(Set<StatusGetOpts> statusOptions,
-      long timeout) throws TezException {
-    long currentStatusPollInterval = statusPollInterval;
-    if(timeout >= 0 && currentStatusPollInterval > timeout) {
-      currentStatusPollInterval = timeout;
-    }
-    long timeoutTime = System.currentTimeMillis() + timeout;
-    while (timeout < 0 || (timeout > 0 && timeoutTime > System.currentTimeMillis())) {
-      if(isComplete()) {
-        break;
-      }
+                                       long timeoutMillis) throws TezException {
+    long timeoutNanos = timeoutMillis * 1000l * 1000l;
+    if (timeoutMillis < 0) {
+      // Negative timeout: wait without a deadline until the DAG completes
+      timeoutNanos = Long.MAX_VALUE;
+    }
+    if (isComplete()) {
+      return getDAGStatus(statusOptions);
+    }
+    while (true) {
+      long nanosLeft;
+      writeLock.lock();
       try {
-        Thread.sleep(currentStatusPollInterval);
+        // Check within the lock to ensure we don't end up waiting after the notify has happened
+        if (isComplete()) {
+          break;
+        }
+        nanosLeft = dagCompleteCondition.awaitNanos(timeoutNanos);
       } catch (InterruptedException e) {
-        throw new TezException(e);
+        throw new TezException("Interrupted while waiting for dag to complete", e);
+      } finally {
+        writeLock.unlock();
+      }
+      if (nanosLeft <= 0) {
+        // Time expired.
+        break;
+      } else {
+        timeoutNanos = nanosLeft;
       }
     }
     return getDAGStatus(statusOptions);
@@ -1209,9 +1218,26 @@ public class DAGImpl implements org.apache.tez.dag.app.dag.DAG,
     }
 
     LOG.info("DAG: " + getID() + " finished with state: " + finalState);
+
+    // Signal dag completion.
+    // The state will move to the final state after the Transition which invoked this method completes.
+    // However, it is OK to send the signal from here itself.
+    // This happens within a writeLock. The dagCompleteCondition check attempts to check for
+    // dagCompletion within the associated lock - so it will block till the full transition
+    // completes and the state updates.
+    notifyDagFinished();
     return finalState;
   }
 
+  /**
+   * Wakes up every thread that is blocked on dagCompleteCondition waiting for this DAG to finish.
+   * The condition was created from the write lock, so the write lock is held while signalling.
+   */
+  private void notifyDagFinished() {
+    writeLock.lock();
+    try {
+      // signalAll rather than signal: more than one timed getDAGStatus call can be waiting on
+      // the condition at once, and signal() releases only one of them, leaving the others
+      // blocked until their own timeout expires.
+      dagCompleteCondition.signalAll();
+    } finally {
+      writeLock.unlock();
+    }
+  }
+
   @Override
   public String getUserName() {
     return userName;

http://git-wip-us.apache.org/repos/asf/tez/blob/60f413d9/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestDAGImpl.java
----------------------------------------------------------------------
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestDAGImpl.java b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestDAGImpl.java
index 62aa453..98c8492 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestDAGImpl.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/dag/impl/TestDAGImpl.java
@@ -25,12 +25,19 @@ import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 
 import java.io.IOException;
+import java.net.URL;
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
+import java.util.EnumSet;
+import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.Callable;
 import java.util.concurrent.Executors;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.concurrent.locks.Condition;
+import java.util.concurrent.locks.ReentrantLock;
 
 import org.apache.commons.lang.StringUtils;
 import org.slf4j.Logger;
@@ -61,7 +68,11 @@ import org.apache.tez.dag.api.TezConfiguration;
 import org.apache.tez.dag.api.EdgeProperty.DataMovementType;
 import org.apache.tez.dag.api.EdgeProperty.DataSourceType;
 import org.apache.tez.dag.api.EdgeProperty.SchedulingType;
+import org.apache.tez.dag.api.TezException;
 import org.apache.tez.dag.api.UserPayload;
+import org.apache.tez.dag.api.client.DAGStatus;
+import org.apache.tez.dag.api.client.DAGStatusBuilder;
+import org.apache.tez.dag.api.client.StatusGetOpts;
 import org.apache.tez.dag.api.oldrecords.TaskState;
 import org.apache.tez.dag.api.records.DAGProtos;
 import org.apache.tez.dag.api.records.DAGProtos.ConfigurationProto;
@@ -1516,6 +1527,121 @@ public class TestDAGImpl {
   }
 
   @SuppressWarnings("unchecked")
+  @Test(timeout = 10000)
+  public void testGetDAGStatusWithWait() throws TezException {
+    initDAG(dag);
+    startDAG(dag);
+    dispatcher.await();
+
+    // All vertices except one succeed
+    for (int i = 0; i < dag.getVertices().size() - 1; ++i) {
+      dispatcher.getEventHandler().handle(new DAGEventVertexCompleted(
+          TezVertexID.getInstance(dagId, i), VertexState.SUCCEEDED));
+    }
+    dispatcher.await();
+    Assert.assertEquals(DAGState.RUNNING, dag.getState());
+    Assert.assertEquals(5, dag.getSuccessfulVertices());
+
+    long dagStatusStartTime = System.currentTimeMillis();
+    DAGStatusBuilder dagStatus = dag.getDAGStatus(EnumSet.noneOf(StatusGetOpts.class), 2000l);
+    long dagStatusEndTime = System.currentTimeMillis();
+    long diff = dagStatusEndTime - dagStatusStartTime;
+    Assert.assertTrue(diff > 1500 && diff < 2500);
+    Assert.assertEquals(DAGStatusBuilder.State.RUNNING, dagStatus.getState());
+  }
+
+  @SuppressWarnings("unchecked")
+  @Test(timeout = 20000)
+  public void testGetDAGStatusReturnOnDagSucceeded() throws InterruptedException, TezException {
+    runTestGetDAGStatusReturnOnDagFinished(DAGStatus.State.SUCCEEDED);
+  }
+
+  @SuppressWarnings("unchecked")
+  @Test(timeout = 20000)
+  public void testGetDAGStatusReturnOnDagFailed() throws InterruptedException, TezException {
+    runTestGetDAGStatusReturnOnDagFinished(DAGStatus.State.FAILED);
+  }
+
+  @SuppressWarnings("unchecked")
+  @Test(timeout = 20000)
+  public void testGetDAGStatusReturnOnDagKilled() throws InterruptedException, TezException {
+    runTestGetDAGStatusReturnOnDagFinished(DAGStatus.State.KILLED);
+  }
+
+  @SuppressWarnings("unchecked")
+  @Test(timeout = 20000)
+  public void testGetDAGStatusReturnOnDagError() throws InterruptedException, TezException {
+    runTestGetDAGStatusReturnOnDagFinished(DAGStatus.State.ERROR);
+  }
+
+
+  /**
+   * Starts the DAG, completes all but the last vertex, then lets a background thread block in the
+   * timed getDAGStatus call while this thread drives the DAG to the requested terminal state about
+   * two seconds later. Asserts that the blocked call took between 1 and 3.5 seconds and returned
+   * the requested state.
+   *
+   * @param testState terminal state to drive the DAG to: SUCCEEDED, FAILED, KILLED or ERROR
+   * @throws UnsupportedOperationException if testState is any other state
+   */
+  @SuppressWarnings("unchecked")
+  public void runTestGetDAGStatusReturnOnDagFinished(DAGStatusBuilder.State testState) throws TezException, InterruptedException {
+    initDAG(dag);
+    startDAG(dag);
+    dispatcher.await();
+
+    // All vertices except one succeed. Each iteration completes a different vertex (index i),
+    // matching testGetDAGStatusWithWait, rather than completing vertex 0 repeatedly.
+    for (int i = 0; i < dag.getVertices().size() - 1; ++i) {
+      dispatcher.getEventHandler().handle(new DAGEventVertexCompleted(
+          TezVertexID.getInstance(dagId, i), VertexState.SUCCEEDED));
+    }
+    dispatcher.await();
+    Assert.assertEquals(DAGState.RUNNING, dag.getState());
+    Assert.assertEquals(5, dag.getSuccessfulVertices());
+
+    ReentrantLock lock = new ReentrantLock();
+    Condition startCondition = lock.newCondition();
+    Condition endCondition = lock.newCondition();
+    DagStatusCheckRunnable statusCheckRunnable =
+        new DagStatusCheckRunnable(lock, startCondition, endCondition);
+    Thread t1 = new Thread(statusCheckRunnable);
+    t1.start();
+    // Wait until the background thread has announced that it is running.
+    lock.lock();
+    try {
+      while (!statusCheckRunnable.started.get()) {
+        startCondition.await();
+      }
+    } finally {
+      lock.unlock();
+    }
+
+    // Sleep for 2 seconds, then send the event that drives the DAG to the requested state.
+    Thread.sleep(2000l);
+    if (testState == DAGStatus.State.SUCCEEDED) {
+      dispatcher.getEventHandler().handle(new DAGEventVertexCompleted(
+          TezVertexID.getInstance(dagId, 5), VertexState.SUCCEEDED));
+    } else if (testState == DAGStatus.State.FAILED) {
+      dispatcher.getEventHandler().handle(new DAGEventVertexCompleted(
+          TezVertexID.getInstance(dagId, 5), VertexState.FAILED));
+    } else if (testState == DAGStatus.State.KILLED) {
+      dispatcher.getEventHandler().handle(new DAGEvent(dagId, DAGEventType.DAG_KILL));
+    } else if (testState == DAGStatus.State.ERROR) {
+      dispatcher.getEventHandler().handle(new DAGEventStartDag(dagId, new LinkedList<URL>()));
+    } else {
+      throw new UnsupportedOperationException("Unsupported state for test: " + testState);
+    }
+    dispatcher.await();
+
+    // Wait for the dag status to return
+    lock.lock();
+    try {
+      while (!statusCheckRunnable.ended.get()) {
+        endCondition.await();
+      }
+    } finally {
+      lock.unlock();
+    }
+
+    long diff = statusCheckRunnable.dagStatusEndTime - statusCheckRunnable.dagStatusStartTime;
+    Assert.assertNotNull(statusCheckRunnable.dagStatus);
+    Assert.assertTrue(diff > 1000 && diff < 3500);
+    Assert.assertEquals(testState, statusCheckRunnable.dagStatus.getState());
+    t1.join();
+  }
+
+
+  @SuppressWarnings("unchecked")
   @Test(timeout = 5000)
   public void testVertexFailureHandling() {
     initDAG(dag);
@@ -1718,4 +1844,52 @@ public class TestDAGImpl {
       return 0;
     }
   }
+
+
+  // Specifically for runTestGetDAGStatusReturnOnDagFinished: runs the timed getDAGStatus call on
+  // its own thread and records when the call started, when it returned, and what it returned.
+  private class DagStatusCheckRunnable implements Runnable {
+
+    // Result and timestamps of the getDAGStatus call; read by the test thread after 'ended' is set.
+    private volatile DAGStatusBuilder dagStatus;
+    private volatile long dagStatusStartTime = -1;
+    private volatile long dagStatusEndTime = -1;
+    private final AtomicBoolean started = new AtomicBoolean(false);
+    private final AtomicBoolean ended = new AtomicBoolean(false);
+
+    // Lock and conditions shared with the test thread for the start / end hand-offs.
+    private final ReentrantLock lock;
+    private final Condition startCondition;
+    private final Condition endCondition;
+
+    public DagStatusCheckRunnable(ReentrantLock lock,
+                                  Condition startCondition,
+                                  Condition endCondition) {
+      this.lock = lock;
+      this.startCondition = startCondition;
+      this.endCondition = endCondition;
+    }
+
+    @Override
+    public void run() {
+      started.set(true);
+      lock.lock();
+      try {
+        startCondition.signal();
+      } finally {
+        lock.unlock();
+      }
+      try {
+        dagStatusStartTime = System.currentTimeMillis();
+        dagStatus = dag.getDAGStatus(EnumSet.noneOf(StatusGetOpts.class), 10000l);
+        dagStatusEndTime = System.currentTimeMillis();
+      } catch (TezException e) {
+        // Do not drop the failure silently: rethrow with the cause so it shows up in the test
+        // output. dagStatus stays null, which the test's assertNotNull detects.
+        throw new IllegalStateException("getDAGStatus failed on the status-check thread", e);
+      } finally {
+        // Always announce the end, including on failure, so the test thread is not left waiting
+        // on endCondition until the JUnit timeout.
+        lock.lock();
+        try {
+          ended.set(true);
+          endCondition.signal();
+        } finally {
+          lock.unlock();
+        }
+      }
+    }
+  }
 }


Mime
View raw message