hive-commits mailing list archives

From gunt...@apache.org
Subject svn commit: r1669277 - in /hive/trunk/ql/src: java/org/apache/hadoop/hive/ql/exec/tez/ test/org/apache/hadoop/hive/ql/exec/tez/
Date Thu, 26 Mar 2015 06:09:46 GMT
Author: gunther
Date: Thu Mar 26 06:09:45 2015
New Revision: 1669277

URL: http://svn.apache.org/r1669277
Log:
HIVE-9976: Possible race condition in DynamicPartitionPruner for <200ms tasks (Siddharth Seth via Gunther Hagleitner)

Added:
    hive/trunk/ql/src/test/org/apache/hadoop/hive/ql/exec/tez/TestDynamicPartitionPruner.java
Modified:
    hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/CustomPartitionVertex.java
    hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/DynamicPartitionPruner.java
    hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/HiveSplitGenerator.java
    hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/SplitGrouper.java

Modified: hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/CustomPartitionVertex.java
URL: http://svn.apache.org/viewvc/hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/CustomPartitionVertex.java?rev=1669277&r1=1669276&r2=1669277&view=diff
==============================================================================
--- hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/CustomPartitionVertex.java (original)
+++ hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/CustomPartitionVertex.java Thu Mar 26 06:09:45 2015
@@ -274,9 +274,8 @@ public class CustomPartitionVertex exten
         for (Integer key : bucketToInitialSplitMap.keySet()) {
           InputSplit[] inputSplitArray =
               (bucketToInitialSplitMap.get(key).toArray(new InputSplit[0]));
-          HiveSplitGenerator hiveSplitGenerator = new HiveSplitGenerator();
           Multimap<Integer, InputSplit> groupedSplit =
-              hiveSplitGenerator.generateGroupedSplits(jobConf, conf, inputSplitArray, waves,
+              grouper.generateGroupedSplits(jobConf, conf, inputSplitArray, waves,
                   availableSlots, inputName, mainWorkName.isEmpty());
           if (mainWorkName.isEmpty() == false) {
             Multimap<Integer, InputSplit> singleBucketToGroupedSplit =
@@ -295,11 +294,10 @@ public class CustomPartitionVertex exten
         // grouped split. This would affect SMB joins where we want to find the smallest key in
         // all the bucket files.
         for (Integer key : bucketToInitialSplitMap.keySet()) {
-          HiveSplitGenerator hiveSplitGenerator = new HiveSplitGenerator();
           InputSplit[] inputSplitArray =
               (bucketToInitialSplitMap.get(key).toArray(new InputSplit[0]));
           Multimap<Integer, InputSplit> groupedSplit =
-              hiveSplitGenerator.generateGroupedSplits(jobConf, conf, inputSplitArray, waves,
+              grouper.generateGroupedSplits(jobConf, conf, inputSplitArray, waves,
                     availableSlots, inputName, false);
             bucketToGroupedSplitMap.putAll(key, groupedSplit.values());
         }

Modified: hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/DynamicPartitionPruner.java
URL: http://svn.apache.org/viewvc/hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/DynamicPartitionPruner.java?rev=1669277&r1=1669276&r2=1669277&view=diff
==============================================================================
--- hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/DynamicPartitionPruner.java (original)
+++ hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/DynamicPartitionPruner.java Thu Mar 26 06:09:45 2015
@@ -31,12 +31,13 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.BlockingQueue;
-import java.util.concurrent.ConcurrentSkipListSet;
 import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.atomic.AtomicBoolean;
 
-import javolution.testing.AssertionException;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
 
+import org.apache.commons.lang3.mutable.MutableInt;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.hive.ql.exec.ExprNodeEvaluator;
@@ -74,30 +75,47 @@ public class DynamicPartitionPruner {
 
   private static final Log LOG = LogFactory.getLog(DynamicPartitionPruner.class);
 
+  private final InputInitializerContext context;
+  private final MapWork work;
+  private final JobConf jobConf;
+
+
   private final Map<String, List<SourceInfo>> sourceInfoMap =
       new HashMap<String, List<SourceInfo>>();
 
   private final BytesWritable writable = new BytesWritable();
 
+  /* Keeps track of all events that need to be processed - irrespective of the source */
   private final BlockingQueue<Object> queue = new LinkedBlockingQueue<Object>();
 
+  /* Keeps track of vertices from which events are expected */
   private final Set<String> sourcesWaitingForEvents = new HashSet<String>();
 
+  // Stores a negative column count per source while initializing; set to #tasks X #columns once the source vertex completes.
+  private final Map<String, MutableInt> numExpectedEventsPerSource = new HashMap<>();
+  private final Map<String, MutableInt> numEventsSeenPerSource = new HashMap<>();
+
   private int sourceInfoCount = 0;
 
   private final Object endOfEvents = new Object();
 
   private int totalEventCount = 0;
 
-  public DynamicPartitionPruner() {
+  public DynamicPartitionPruner(InputInitializerContext context, MapWork work, JobConf jobConf) throws
+      SerDeException {
+    this.context = context;
+    this.work = work;
+    this.jobConf = jobConf;
+    synchronized (this) {
+      initialize();
+    }
   }
 
-  public void prune(MapWork work, JobConf jobConf, InputInitializerContext context)
+  public void prune()
       throws SerDeException, IOException,
       InterruptedException, HiveException {
 
     synchronized(sourcesWaitingForEvents) {
-      initialize(work, jobConf);
 
       if (sourcesWaitingForEvents.isEmpty()) {
         return;
@@ -112,11 +130,11 @@ public class DynamicPartitionPruner {
       }
     }
 
-    LOG.info("Waiting for events (" + sourceInfoCount + " items) ...");
+    LOG.info("Waiting for events (" + sourceInfoCount + " sources) ...");
     // synchronous event processing loop. Won't return until all events have
     // been processed.
     this.processEvents();
-    this.prunePartitions(work, context);
+    this.prunePartitions();
     LOG.info("Ok to proceed.");
   }
 
@@ -129,25 +147,38 @@ public class DynamicPartitionPruner {
     sourceInfoCount = 0;
   }
 
-  public void initialize(MapWork work, JobConf jobConf) throws SerDeException {
+  private void initialize() throws SerDeException {
     this.clear();
     Map<String, SourceInfo> columnMap = new HashMap<String, SourceInfo>();
+    // sources represent vertex names
     Set<String> sources = work.getEventSourceTableDescMap().keySet();
 
     sourcesWaitingForEvents.addAll(sources);
 
     for (String s : sources) {
+      // Set to 0 to start with. This is decremented once for each column for which this source
+      // generates events, and is eventually used to determine the number of expected
+      // events for the source: #columns X #tasks
+      numExpectedEventsPerSource.put(s, new MutableInt(0));
+      numEventsSeenPerSource.put(s, new MutableInt(0));
+      // Virtual relation generated by the reduce sink
       List<TableDesc> tables = work.getEventSourceTableDescMap().get(s);
+      // Real column name - on which the operation is being performed
       List<String> columnNames = work.getEventSourceColumnNameMap().get(s);
+      // Expression for the operation. e.g. N^2 > 10
       List<ExprNodeDesc> partKeyExprs = work.getEventSourcePartKeyExprMap().get(s);
+      // eventSourceTableDesc, eventSourceColumnName, eventSourcePartKeyExpr move in lock-step.
+      // One entry is added to each at the same time.
 
       Iterator<String> cit = columnNames.iterator();
       Iterator<ExprNodeDesc> pit = partKeyExprs.iterator();
+      // A single source can process multiple columns, and will send an event for each of them.
       for (TableDesc t : tables) {
+        numExpectedEventsPerSource.get(s).decrement();
         ++sourceInfoCount;
         String columnName = cit.next();
         ExprNodeDesc partKeyExpr = pit.next();
-        SourceInfo si = new SourceInfo(t, partKeyExpr, columnName, jobConf);
+        SourceInfo si = createSourceInfo(t, partKeyExpr, columnName, jobConf);
         if (!sourceInfoMap.containsKey(s)) {
           sourceInfoMap.put(s, new ArrayList<SourceInfo>());
         }
@@ -157,6 +188,8 @@ public class DynamicPartitionPruner {
         // We could have multiple sources restrict the same column, need to take
         // the union of the values in that case.
         if (columnMap.containsKey(columnName)) {
+          // All sources are initialized up front. Events from different sources end up being added to the same list.
+          // Pruning is disabled if any source sends an event which signals that pruning should be skipped.
           si.values = columnMap.get(columnName).values;
           si.skipPruning = columnMap.get(columnName).skipPruning;
         }
@@ -165,25 +198,27 @@ public class DynamicPartitionPruner {
     }
   }
 
-  private void prunePartitions(MapWork work, InputInitializerContext context) throws HiveException {
+  private void prunePartitions() throws HiveException {
     int expectedEvents = 0;
-    for (String source : this.sourceInfoMap.keySet()) {
-      for (SourceInfo si : this.sourceInfoMap.get(source)) {
+    for (Map.Entry<String, List<SourceInfo>> entry : this.sourceInfoMap.entrySet()) {
+      String source = entry.getKey();
+      for (SourceInfo si : entry.getValue()) {
         int taskNum = context.getVertexNumTasks(source);
-        LOG.info("Expecting " + taskNum + " events for vertex " + source);
+        LOG.info("Expecting " + taskNum + " events for vertex " + source + ", for column " + si.columnName);
         expectedEvents += taskNum;
-        prunePartitionSingleSource(source, si, work);
+        prunePartitionSingleSource(source, si);
       }
     }
 
     // sanity check. all tasks must submit events for us to succeed.
     if (expectedEvents != totalEventCount) {
       LOG.error("Expecting: " + expectedEvents + ", received: " + totalEventCount);
-      throw new HiveException("Incorrect event count in dynamic parition pruning");
+      throw new HiveException("Incorrect event count in dynamic partition pruning");
     }
   }
 
-  private void prunePartitionSingleSource(String source, SourceInfo si, MapWork work)
+  @VisibleForTesting
+  protected void prunePartitionSingleSource(String source, SourceInfo si)
       throws HiveException {
 
     if (si.skipPruning.get()) {
@@ -223,11 +258,11 @@ public class DynamicPartitionPruner {
     ExprNodeEvaluator eval = ExprNodeEvaluatorFactory.get(si.partKey);
     eval.initialize(soi);
 
-    applyFilterToPartitions(work, converter, eval, columnName, values);
+    applyFilterToPartitions(converter, eval, columnName, values);
   }
 
   @SuppressWarnings("rawtypes")
-  private void applyFilterToPartitions(MapWork work, Converter converter, ExprNodeEvaluator eval,
+  private void applyFilterToPartitions(Converter converter, ExprNodeEvaluator eval,
       String columnName, Set<Object> values) throws HiveException {
 
     Object[] row = new Object[1];
@@ -238,12 +273,12 @@ public class DynamicPartitionPruner {
       PartitionDesc desc = work.getPathToPartitionInfo().get(p);
       Map<String, String> spec = desc.getPartSpec();
       if (spec == null) {
-        throw new AssertionException("No partition spec found in dynamic pruning");
+        throw new IllegalStateException("No partition spec found in dynamic pruning");
       }
 
       String partValueString = spec.get(columnName);
       if (partValueString == null) {
-        throw new AssertionException("Could not find partition value for column: " + columnName);
+        throw new IllegalStateException("Could not find partition value for column: " + columnName);
       }
 
       Object partValue = converter.convert(partValueString);
@@ -267,17 +302,38 @@ public class DynamicPartitionPruner {
     }
   }
 
+  @VisibleForTesting
+  protected SourceInfo createSourceInfo(TableDesc t, ExprNodeDesc partKeyExpr, String columnName,
+                                        JobConf jobConf) throws
+      SerDeException {
+    return new SourceInfo(t, partKeyExpr, columnName, jobConf);
+
+  }
+
   @SuppressWarnings("deprecation")
-  private static class SourceInfo {
+  @VisibleForTesting
+  static class SourceInfo {
     public final ExprNodeDesc partKey;
     public final Deserializer deserializer;
     public final StructObjectInspector soi;
     public final StructField field;
     public final ObjectInspector fieldInspector;
+    /* Set of required partition values - populated while processing each event */
     public Set<Object> values = new HashSet<Object>();
+    /* Whether to skip pruning - set when an event payload signals skip, e.g. because the payload would be too large */
     public AtomicBoolean skipPruning = new AtomicBoolean();
     public final String columnName;
 
+    @VisibleForTesting // Only used for testing.
+    SourceInfo(TableDesc table, ExprNodeDesc partKey, String columnName, JobConf jobConf, Object forTesting) {
+      this.partKey = partKey;
+      this.columnName = columnName;
+      this.deserializer = null;
+      this.soi = null;
+      this.field = null;
+      this.fieldInspector = null;
+    }
+
     public SourceInfo(TableDesc table, ExprNodeDesc partKey, String columnName, JobConf jobConf)
         throws SerDeException {
 
@@ -328,52 +384,60 @@ public class DynamicPartitionPruner {
   }
 
   @SuppressWarnings("deprecation")
-  private String processPayload(ByteBuffer payload, String sourceName) throws SerDeException,
+  @VisibleForTesting
+  protected String processPayload(ByteBuffer payload, String sourceName) throws SerDeException,
       IOException {
 
     DataInputStream in = new DataInputStream(new ByteBufferBackedInputStream(payload));
-    String columnName = in.readUTF();
-    boolean skip = in.readBoolean();
+    try {
+      String columnName = in.readUTF();
 
-    LOG.info("Source of event: " + sourceName);
+      LOG.info("Source of event: " + sourceName);
 
-    List<SourceInfo> infos = this.sourceInfoMap.get(sourceName);
-    if (infos == null) {
-      in.close();
-      throw new AssertionException("no source info for event source: " + sourceName);
-    }
-
-    SourceInfo info = null;
-    for (SourceInfo si : infos) {
-      if (columnName.equals(si.columnName)) {
-        info = si;
-        break;
+      List<SourceInfo> infos = this.sourceInfoMap.get(sourceName);
+      if (infos == null) {
+        throw new IllegalStateException("no source info for event source: " + sourceName);
       }
-    }
-
-    if (info == null) {
-      in.close();
-      throw new AssertionException("no source info for column: " + columnName);
-    }
 
-    if (skip) {
-      info.skipPruning.set(true);
-    }
-
-    while (payload.hasRemaining()) {
-      writable.readFields(in);
-
-      Object row = info.deserializer.deserialize(writable);
+      SourceInfo info = null;
+      for (SourceInfo si : infos) {
+        if (columnName.equals(si.columnName)) {
+          info = si;
+          break;
+        }
+      }
 
-      Object value = info.soi.getStructFieldData(row, info.field);
-      value = ObjectInspectorUtils.copyToStandardObject(value, info.fieldInspector);
+      if (info == null) {
+        throw new IllegalStateException("no source info for column: " + columnName);
+      }
 
-      if (LOG.isDebugEnabled()) {
-        LOG.debug("Adding: " + value + " to list of required partitions");
+      if (info.skipPruning.get()) {
+        // Marked as skipped previously. Don't bother processing the rest of the payload.
+      } else {
+        boolean skip = in.readBoolean();
+        if (skip) {
+          info.skipPruning.set(true);
+        } else {
+          while (payload.hasRemaining()) {
+            writable.readFields(in);
+
+            Object row = info.deserializer.deserialize(writable);
+
+            Object value = info.soi.getStructFieldData(row, info.field);
+            value = ObjectInspectorUtils.copyToStandardObject(value, info.fieldInspector);
+
+            if (LOG.isDebugEnabled()) {
+              LOG.debug("Adding: " + value + " to list of required partitions");
+            }
+            info.values.add(value);
+          }
+        }
+      }
+    } finally {
+      if (in != null) {
+        in.close();
       }
-      info.values.add(value);
     }
-    in.close();
     return sourceName;
   }
 
@@ -409,23 +473,47 @@ public class DynamicPartitionPruner {
     synchronized(sourcesWaitingForEvents) {
       if (sourcesWaitingForEvents.contains(event.getSourceVertexName())) {
         ++totalEventCount;
+        numEventsSeenPerSource.get(event.getSourceVertexName()).increment();
         queue.offer(event);
+        checkForSourceCompletion(event.getSourceVertexName());
       }
     }
   }
 
   public void processVertex(String name) {
     LOG.info("Vertex succeeded: " + name);
-
     synchronized(sourcesWaitingForEvents) {
-      sourcesWaitingForEvents.remove(name);
+      // Get a deterministic count of the number of tasks for the vertex.
+      MutableInt prevVal = numExpectedEventsPerSource.get(name);
+      int prevValInt = prevVal.intValue();
+      Preconditions.checkState(prevValInt < 0,
+          "Invalid value for numExpectedEvents for source: " + name + ", oldVal=" + prevValInt);
+      prevVal.setValue((-1) * prevValInt * context.getVertexNumTasks(name));
+      checkForSourceCompletion(name);
+    }
+  }
 
-      if (sourcesWaitingForEvents.isEmpty()) {
-        // we've got what we need; mark the queue
-        queue.offer(endOfEvents);
-      } else {
-        LOG.info("Waiting for " + sourcesWaitingForEvents.size() + " events.");
+  private void checkForSourceCompletion(String name) {
+    int expectedEvents = numExpectedEventsPerSource.get(name).getValue();
+    if (expectedEvents < 0) {
+      // Expected events not updated yet - vertex SUCCESS notification not received.
+      return;
+    } else {
+      int processedEvents = numEventsSeenPerSource.get(name).getValue();
+      if (processedEvents == expectedEvents) {
+        sourcesWaitingForEvents.remove(name);
+        if (sourcesWaitingForEvents.isEmpty()) {
+          // we've got what we need; mark the queue
+          queue.offer(endOfEvents);
+        } else {
+          LOG.info("Waiting for " + sourcesWaitingForEvents.size() + " sources.");
+        }
+      } else if (processedEvents > expectedEvents) {
+        throw new IllegalStateException(
+            "Received too many events for " + name + ", Expected=" + expectedEvents +
+                ", Received=" + processedEvents);
       }
+      return;
     }
   }
 }
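
For orientation, the heart of this change is the per-source accounting added to DynamicPartitionPruner above: events are counted as they arrive, the expected total only becomes known once the vertex SUCCESS notification supplies the final task count, and a source is released only when the two counts agree. A minimal, self-contained sketch of that protocol (simplified names and plain Java objects in place of the Tez/Hive types, not the actual class) might look like this:

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;

/** Sketch only: mirrors the expected-vs-seen event accounting, not the actual Hive class. */
class SourceCompletionSketch {
  private final Object endOfEvents = new Object();
  private final BlockingQueue<Object> queue = new LinkedBlockingQueue<>();
  private final Set<String> sourcesWaitingForEvents = new HashSet<>();
  // Negative while columns are being registered; becomes #columns * #tasks on vertex SUCCESS.
  private final Map<String, Integer> numExpectedEventsPerSource = new HashMap<>();
  private final Map<String, Integer> numEventsSeenPerSource = new HashMap<>();

  /** Called from the owner's constructor, once per column, before any event can be delivered. */
  synchronized void registerSourceColumn(String source) {
    sourcesWaitingForEvents.add(source);
    numExpectedEventsPerSource.merge(source, -1, Integer::sum); // one more column for this source
    numEventsSeenPerSource.putIfAbsent(source, 0);
  }

  /** Called for every event received from a registered source vertex. */
  synchronized void onEvent(String source, Object event) {
    if (!sourcesWaitingForEvents.contains(source)) {
      return;
    }
    numEventsSeenPerSource.merge(source, 1, Integer::sum);
    queue.offer(event);
    checkForSourceCompletion(source);
  }

  /** Called once the source vertex succeeds and its task count is deterministic. */
  synchronized void onVertexSuccess(String source, int numTasks) {
    int negColumnCount = numExpectedEventsPerSource.get(source);        // still negative here
    numExpectedEventsPerSource.put(source, -negColumnCount * numTasks); // #columns * #tasks
    checkForSourceCompletion(source);
  }

  private void checkForSourceCompletion(String source) {
    int expected = numExpectedEventsPerSource.get(source);
    if (expected < 0) {
      return; // vertex SUCCESS not seen yet, so the expected count is unknown
    }
    int seen = numEventsSeenPerSource.get(source);
    if (seen == expected) {
      sourcesWaitingForEvents.remove(source);
      if (sourcesWaitingForEvents.isEmpty()) {
        queue.offer(endOfEvents); // unblocks the synchronous processing loop in prune()
      }
    } else if (seen > expected) {
      throw new IllegalStateException("Too many events for " + source);
    }
  }
}

Because the real pruner now performs this registration in its constructor, before the initializer is handed back to Tez, an event from a very fast (<200ms) task can no longer arrive before its source is known, which is the window HIVE-9976 closes.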

Modified: hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/HiveSplitGenerator.java
URL: http://svn.apache.org/viewvc/hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/HiveSplitGenerator.java?rev=1669277&r1=1669276&r2=1669277&view=diff
==============================================================================
--- hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/HiveSplitGenerator.java (original)
+++ hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/HiveSplitGenerator.java Thu Mar 26 06:09:45 2015
@@ -19,22 +19,17 @@
 package org.apache.hadoop.hive.ql.exec.tez;
 
 import java.io.IOException;
-import java.util.Collection;
-import java.util.HashMap;
 import java.util.List;
-import java.util.Map;
 
+import com.google.common.base.Preconditions;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.hive.common.JavaUtils;
 import org.apache.hadoop.hive.ql.exec.Utilities;
-import org.apache.hadoop.hive.ql.io.HiveFileFormatUtils;
 import org.apache.hadoop.hive.ql.plan.MapWork;
-import org.apache.hadoop.hive.ql.plan.PartitionDesc;
+import org.apache.hadoop.hive.serde2.SerDeException;
 import org.apache.hadoop.hive.shims.ShimLoader;
-import org.apache.hadoop.mapred.FileSplit;
 import org.apache.hadoop.mapred.InputFormat;
 import org.apache.hadoop.mapred.InputSplit;
 import org.apache.hadoop.mapred.JobConf;
@@ -57,7 +52,6 @@ import org.apache.tez.runtime.api.events
 import org.apache.tez.runtime.api.events.InputDataInformationEvent;
 import org.apache.tez.runtime.api.events.InputInitializerEvent;
 
-import com.google.common.collect.ArrayListMultimap;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Multimap;
 
@@ -71,43 +65,44 @@ public class HiveSplitGenerator extends
 
   private static final Log LOG = LogFactory.getLog(HiveSplitGenerator.class);
 
-  private static final SplitGrouper grouper = new SplitGrouper();
-  private final DynamicPartitionPruner pruner = new DynamicPartitionPruner();
-  private InputInitializerContext context;
-  private static Map<Map<String, PartitionDesc>, Map<String, PartitionDesc>> cache =
-      new HashMap<Map<String, PartitionDesc>, Map<String, PartitionDesc>>();
+  private final DynamicPartitionPruner pruner;
+  private final Configuration conf;
+  private final JobConf jobConf;
+  private final MRInputUserPayloadProto userPayloadProto;
+  private final SplitGrouper splitGrouper = new SplitGrouper();
 
-  public HiveSplitGenerator(InputInitializerContext initializerContext) {
+
+  public HiveSplitGenerator(InputInitializerContext initializerContext) throws IOException,
+      SerDeException {
     super(initializerContext);
-  }
+    Preconditions.checkNotNull(initializerContext);
+    userPayloadProto =
+        MRInputHelpers.parseMRInputPayload(initializerContext.getInputUserPayload());
 
-  public HiveSplitGenerator() {
-    this(null);
-  }
+    this.conf =
+        TezUtils.createConfFromByteString(userPayloadProto.getConfigurationBytes());
 
-  @Override
-  public List<Event> initialize() throws Exception {
-    InputInitializerContext rootInputContext = getContext();
+    this.jobConf = new JobConf(conf);
+    // Read all credentials into the credentials instance stored in JobConf.
+    ShimLoader.getHadoopShims().getMergedCredentials(jobConf);
 
-    context = rootInputContext;
+    MapWork work = Utilities.getMapWork(jobConf);
 
-    MRInputUserPayloadProto userPayloadProto =
-        MRInputHelpers.parseMRInputPayload(rootInputContext.getInputUserPayload());
+    // Events can start coming in the moment the InputInitializer is created. The pruner
+    // must be set up and initialized here so that its structures are ready to accept events.
+    // Setting it up in initialize() leaves a window where events may come in before the pruner is
+    // initialized, which may cause it to drop events.
+    pruner = new DynamicPartitionPruner(initializerContext, work, jobConf);
 
-    Configuration conf =
-        TezUtils.createConfFromByteString(userPayloadProto.getConfigurationBytes());
+  }
 
+  @Override
+  public List<Event> initialize() throws Exception {
     boolean sendSerializedEvents =
         conf.getBoolean("mapreduce.tez.input.initializer.serialize.event.payload", true);
 
-    // Read all credentials into the credentials instance stored in JobConf.
-    JobConf jobConf = new JobConf(conf);
-    ShimLoader.getHadoopShims().getMergedCredentials(jobConf);
-
-    MapWork work = Utilities.getMapWork(jobConf);
-
     // perform dynamic partition pruning
-    pruner.prune(work, jobConf, context);
+    pruner.prune();
 
     InputSplitInfoMem inputSplitInfo = null;
     String realInputFormatName = conf.get("mapred.input.format.class");
@@ -118,8 +113,8 @@ public class HiveSplitGenerator extends
           (InputFormat<?, ?>) ReflectionUtils.newInstance(JavaUtils.loadClass(realInputFormatName),
               jobConf);
 
-      int totalResource = rootInputContext.getTotalAvailableResource().getMemory();
-      int taskResource = rootInputContext.getVertexTaskResource().getMemory();
+      int totalResource = getContext().getTotalAvailableResource().getMemory();
+      int taskResource = getContext().getVertexTaskResource().getMemory();
       int availableSlots = totalResource / taskResource;
 
       // Create the un-grouped splits
@@ -132,12 +127,12 @@ public class HiveSplitGenerator extends
           + " available slots, " + waves + " waves. Input format is: " + realInputFormatName);
 
       Multimap<Integer, InputSplit> groupedSplits =
-          generateGroupedSplits(jobConf, conf, splits, waves, availableSlots);
+          splitGrouper.generateGroupedSplits(jobConf, conf, splits, waves, availableSlots);
       // And finally return them in a flat array
       InputSplit[] flatSplits = groupedSplits.values().toArray(new InputSplit[0]);
       LOG.info("Number of grouped splits: " + flatSplits.length);
 
-      List<TaskLocationHint> locationHints = grouper.createTaskLocationHints(flatSplits);
+      List<TaskLocationHint> locationHints = splitGrouper.createTaskLocationHints(flatSplits);
 
       Utilities.clearWork(jobConf);
 
@@ -158,87 +153,7 @@ public class HiveSplitGenerator extends
   }
 
 
-  public Multimap<Integer, InputSplit> generateGroupedSplits(JobConf jobConf,
-      Configuration conf, InputSplit[] splits, float waves, int availableSlots)
-      throws Exception {
-    return generateGroupedSplits(jobConf, conf, splits, waves, availableSlots, null, true);
-  }
-
-  public Multimap<Integer, InputSplit> generateGroupedSplits(JobConf jobConf,
-      Configuration conf, InputSplit[] splits, float waves, int availableSlots, String inputName,
-      boolean groupAcrossFiles) throws Exception {
-
-    MapWork work = populateMapWork(jobConf, inputName);
-    Multimap<Integer, InputSplit> bucketSplitMultiMap =
-        ArrayListMultimap.<Integer, InputSplit> create();
-
-    int i = 0;
-    InputSplit prevSplit = null;
-    for (InputSplit s : splits) {
-      // this is the bit where we make sure we don't group across partition
-      // schema boundaries
-      if (schemaEvolved(s, prevSplit, groupAcrossFiles, work)) {
-        ++i;
-        prevSplit = s;
-      }
-      bucketSplitMultiMap.put(i, s);
-    }
-    LOG.info("# Src groups for split generation: " + (i + 1));
-
-    // group them into the chunks we want
-    Multimap<Integer, InputSplit> groupedSplits =
-        grouper.group(jobConf, bucketSplitMultiMap, availableSlots, waves);
-
-    return groupedSplits;
-  }
-
-  private MapWork populateMapWork(JobConf jobConf, String inputName) {
-    MapWork work = null;
-    if (inputName != null) {
-      work = (MapWork) Utilities.getMergeWork(jobConf, inputName);
-      // work can still be null if there is no merge work for this input
-    }
-    if (work == null) {
-      work = Utilities.getMapWork(jobConf);
-    }
 
-    return work;
-  }
-
-  public boolean schemaEvolved(InputSplit s, InputSplit prevSplit, boolean groupAcrossFiles,
-      MapWork work) throws IOException {
-    boolean retval = false;
-    Path path = ((FileSplit) s).getPath();
-    PartitionDesc pd =
-        HiveFileFormatUtils.getPartitionDescFromPathRecursively(work.getPathToPartitionInfo(),
-            path, cache);
-    String currentDeserializerClass = pd.getDeserializerClassName();
-    Class<?> currentInputFormatClass = pd.getInputFileFormatClass();
-
-    Class<?> previousInputFormatClass = null;
-    String previousDeserializerClass = null;
-    if (prevSplit != null) {
-      Path prevPath = ((FileSplit) prevSplit).getPath();
-      if (!groupAcrossFiles) {
-        return !path.equals(prevPath);
-      }
-      PartitionDesc prevPD =
-          HiveFileFormatUtils.getPartitionDescFromPathRecursively(work.getPathToPartitionInfo(),
-              prevPath, cache);
-      previousDeserializerClass = prevPD.getDeserializerClassName();
-      previousInputFormatClass = prevPD.getInputFileFormatClass();
-    }
-
-    if ((currentInputFormatClass != previousInputFormatClass)
-        || (!currentDeserializerClass.equals(previousDeserializerClass))) {
-      retval = true;
-    }
-
-    if (LOG.isDebugEnabled()) {
-      LOG.debug("Adding split " + path + " to src new group? " + retval);
-    }
-    return retval;
-  }
 
   private List<Event> createEventList(boolean sendSerializedEvents, InputSplitInfoMem inputSplitInfo) {
 

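The HiveSplitGenerator change above is the other half of the fix: the pruner, configuration, and payload are now built in the constructor, since Tez can begin delivering InputInitializerEvents as soon as the initializer object exists, while initialize() may run noticeably later. A hedged sketch of that pattern, using made-up names (EventSink, SplitSourceSketch) rather than the real Tez interfaces:

/** Sketch only: constructor-time wiring so that early events are never dropped. */
interface EventSink {
  void addEvent(Object event);                 // may be called at any time after construction
  void awaitAllEvents() throws InterruptedException;
}

abstract class SplitSourceSketch {
  // Built eagerly; an event that arrives before initialize() runs still has somewhere to go.
  private final EventSink pruner;

  protected SplitSourceSketch(EventSink pruner) {
    this.pruner = pruner;                      // previously created lazily inside initialize()
  }

  /** The framework may call this from another thread, before or during initialize(). */
  public final void handleInputInitializerEvent(Object event) {
    pruner.addEvent(event);
  }

  public Object initialize() throws InterruptedException {
    pruner.awaitAllEvents();                   // blocks until every expected event has been processed
    return generateSplits();                   // grouping then runs against the pruned partitions
  }

  protected abstract Object generateSplits();
}

The split-grouping helpers that used to live here (generateGroupedSplits, populateMapWork, schemaEvolved) moved into SplitGrouper, so both CustomPartitionVertex and HiveSplitGenerator now call one grouper instead of constructing throwaway HiveSplitGenerator instances.
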
Modified: hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/SplitGrouper.java
URL: http://svn.apache.org/viewvc/hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/SplitGrouper.java?rev=1669277&r1=1669276&r2=1669277&view=diff
==============================================================================
--- hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/SplitGrouper.java (original)
+++ hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/SplitGrouper.java Thu Mar 26 06:09:45 2015
@@ -26,13 +26,20 @@ import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.hive.ql.exec.Utilities;
+import org.apache.hadoop.hive.ql.io.HiveFileFormatUtils;
 import org.apache.hadoop.hive.ql.io.HiveInputFormat;
+import org.apache.hadoop.hive.ql.plan.MapWork;
+import org.apache.hadoop.hive.ql.plan.PartitionDesc;
 import org.apache.hadoop.mapred.FileSplit;
 import org.apache.hadoop.mapred.InputSplit;
+import org.apache.hadoop.mapred.JobConf;
 import org.apache.hadoop.mapred.split.TezGroupedSplit;
 import org.apache.hadoop.mapred.split.TezMapredSplitsGrouper;
 import org.apache.tez.dag.api.TaskLocationHint;
@@ -49,8 +56,15 @@ public class SplitGrouper {
 
   private static final Log LOG = LogFactory.getLog(SplitGrouper.class);
 
+  // TODO This needs to be looked at. Map of Map to Map... Made concurrent for now since split generation
+  // can happen in parallel.
+  private static final Map<Map<String, PartitionDesc>, Map<String, PartitionDesc>> cache =
+      new ConcurrentHashMap<>();
+
   private final TezMapredSplitsGrouper tezGrouper = new TezMapredSplitsGrouper();
 
+
+
   /**
    * group splits for each bucket separately - while evenly filling all the
    * available slots with tasks
@@ -87,12 +101,83 @@ public class SplitGrouper {
     return bucketGroupedSplitMultimap;
   }
 
+
+  /**
+   * Create task location hints from a set of input splits
+   * @param splits the actual splits
+   * @return taskLocationHints - 1 per input split specified
+   * @throws IOException
+   */
+  public List<TaskLocationHint> createTaskLocationHints(InputSplit[] splits) throws IOException {
+
+    List<TaskLocationHint> locationHints = Lists.newArrayListWithCapacity(splits.length);
+
+    for (InputSplit split : splits) {
+      String rack = (split instanceof TezGroupedSplit) ? ((TezGroupedSplit) split).getRack() : null;
+      if (rack == null) {
+        if (split.getLocations() != null) {
+          locationHints.add(TaskLocationHint.createTaskLocationHint(new HashSet<String>(Arrays.asList(split
+              .getLocations())), null));
+        } else {
+          locationHints.add(TaskLocationHint.createTaskLocationHint(null, null));
+        }
+      } else {
+        locationHints.add(TaskLocationHint.createTaskLocationHint(null, Collections.singleton(rack)));
+      }
+    }
+
+    return locationHints;
+  }
+
+  /** Generate groups of splits, separated by schema evolution boundaries */
+  public Multimap<Integer, InputSplit> generateGroupedSplits(JobConf jobConf,
+                                                                    Configuration conf,
+                                                                    InputSplit[] splits,
+                                                                    float waves, int availableSlots)
+      throws Exception {
+    return generateGroupedSplits(jobConf, conf, splits, waves, availableSlots, null, true);
+  }
+
+  /** Generate groups of splits, separated by schema evolution boundaries */
+  public Multimap<Integer, InputSplit> generateGroupedSplits(JobConf jobConf,
+                                                                    Configuration conf,
+                                                                    InputSplit[] splits,
+                                                                    float waves, int availableSlots,
+                                                                    String inputName,
+                                                                    boolean groupAcrossFiles) throws
+      Exception {
+
+    MapWork work = populateMapWork(jobConf, inputName);
+    Multimap<Integer, InputSplit> bucketSplitMultiMap =
+        ArrayListMultimap.<Integer, InputSplit> create();
+
+    int i = 0;
+    InputSplit prevSplit = null;
+    for (InputSplit s : splits) {
+      // this is the bit where we make sure we don't group across partition
+      // schema boundaries
+      if (schemaEvolved(s, prevSplit, groupAcrossFiles, work)) {
+        ++i;
+        prevSplit = s;
+      }
+      bucketSplitMultiMap.put(i, s);
+    }
+    LOG.info("# Src groups for split generation: " + (i + 1));
+
+    // group them into the chunks we want
+    Multimap<Integer, InputSplit> groupedSplits =
+        this.group(jobConf, bucketSplitMultiMap, availableSlots, waves);
+
+    return groupedSplits;
+  }
+
+
   /**
    * get the size estimates for each bucket in tasks. This is used to make sure
    * we allocate the head room evenly
    */
   private Map<Integer, Integer> estimateBucketSizes(int availableSlots, float waves,
-      Map<Integer, Collection<InputSplit>> bucketSplitMap) {
+                                                    Map<Integer, Collection<InputSplit>> bucketSplitMap) {
 
     // mapping of bucket id to size of all splits in bucket in bytes
     Map<Integer, Long> bucketSizeMap = new HashMap<Integer, Long>();
@@ -147,24 +232,54 @@ public class SplitGrouper {
     return bucketTaskMap;
   }
 
-  public List<TaskLocationHint> createTaskLocationHints(InputSplit[] splits) throws IOException {
+  private static MapWork populateMapWork(JobConf jobConf, String inputName) {
+    MapWork work = null;
+    if (inputName != null) {
+      work = (MapWork) Utilities.getMergeWork(jobConf, inputName);
+      // work can still be null if there is no merge work for this input
+    }
+    if (work == null) {
+      work = Utilities.getMapWork(jobConf);
+    }
 
-    List<TaskLocationHint> locationHints = Lists.newArrayListWithCapacity(splits.length);
+    return work;
+  }
 
-    for (InputSplit split : splits) {
-      String rack = (split instanceof TezGroupedSplit) ? ((TezGroupedSplit) split).getRack() : null;
-      if (rack == null) {
-        if (split.getLocations() != null) {
-          locationHints.add(TaskLocationHint.createTaskLocationHint(new HashSet<String>(Arrays.asList(split
-              .getLocations())), null));
-        } else {
-          locationHints.add(TaskLocationHint.createTaskLocationHint(null, null));
-        }
-      } else {
-        locationHints.add(TaskLocationHint.createTaskLocationHint(null, Collections.singleton(rack)));
+  private boolean schemaEvolved(InputSplit s, InputSplit prevSplit, boolean groupAcrossFiles,
+                                       MapWork work) throws IOException {
+    boolean retval = false;
+    Path path = ((FileSplit) s).getPath();
+    PartitionDesc pd =
+        HiveFileFormatUtils.getPartitionDescFromPathRecursively(work.getPathToPartitionInfo(),
+            path, cache);
+    String currentDeserializerClass = pd.getDeserializerClassName();
+    Class<?> currentInputFormatClass = pd.getInputFileFormatClass();
+
+    Class<?> previousInputFormatClass = null;
+    String previousDeserializerClass = null;
+    if (prevSplit != null) {
+      Path prevPath = ((FileSplit) prevSplit).getPath();
+      if (!groupAcrossFiles) {
+        return !path.equals(prevPath);
       }
+      PartitionDesc prevPD =
+          HiveFileFormatUtils.getPartitionDescFromPathRecursively(work.getPathToPartitionInfo(),
+              prevPath, cache);
+      previousDeserializerClass = prevPD.getDeserializerClassName();
+      previousInputFormatClass = prevPD.getInputFileFormatClass();
     }
 
-    return locationHints;
+    if ((currentInputFormatClass != previousInputFormatClass)
+        || (!currentDeserializerClass.equals(previousDeserializerClass))) {
+      retval = true;
+    }
+
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("Adding split " + path + " to src new group? " + retval);
+    }
+    return retval;
   }
+
+
+
 }
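
A smaller, related change sits in SplitGrouper above: the path-to-PartitionDesc cache is now a static ConcurrentHashMap because split generation for different inputs can run concurrently against the shared map. The snippet below is only a hypothetical illustration of that concern (a made-up SharedDescCache using computeIfAbsent, not how HiveFileFormatUtils actually consumes the cache):

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

/** Hypothetical stand-in showing why a shared, static memo cache must be a concurrent map. */
final class SharedDescCache {
  private static final ConcurrentMap<String, String> CACHE = new ConcurrentHashMap<>();

  /** computeIfAbsent keeps concurrent split generators from racing on the same key. */
  static String deserializerFor(String partitionPath) {
    return CACHE.computeIfAbsent(partitionPath, SharedDescCache::resolve);
  }

  private static String resolve(String partitionPath) {
    // Placeholder for the real lookup via HiveFileFormatUtils.getPartitionDescFromPathRecursively.
    return "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"; // hypothetical result
  }
}

A plain HashMap offers no such guarantee under concurrent writes, which is why the field was changed.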

Added: hive/trunk/ql/src/test/org/apache/hadoop/hive/ql/exec/tez/TestDynamicPartitionPruner.java
URL: http://svn.apache.org/viewvc/hive/trunk/ql/src/test/org/apache/hadoop/hive/ql/exec/tez/TestDynamicPartitionPruner.java?rev=1669277&view=auto
==============================================================================
--- hive/trunk/ql/src/test/org/apache/hadoop/hive/ql/exec/tez/TestDynamicPartitionPruner.java (added)
+++ hive/trunk/ql/src/test/org/apache/hadoop/hive/ql/exec/tez/TestDynamicPartitionPruner.java Thu Mar 26 06:09:45 2015
@@ -0,0 +1,532 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ *  you may not use this file except in compliance with the License.
+ *  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ *  Unless required by applicable law or agreed to in writing, software
+ *  distributed under the License is distributed on an "AS IS" BASIS,
+ *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ *  See the License for the specific language governing permissions and
+ *  limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec.tez;
+
+import static org.junit.Assert.assertFalse;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.locks.Condition;
+import java.util.concurrent.locks.ReentrantLock;
+
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
+import org.apache.hadoop.hive.ql.plan.MapWork;
+import org.apache.hadoop.hive.ql.plan.TableDesc;
+import org.apache.hadoop.hive.serde2.SerDeException;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.tez.runtime.api.InputInitializerContext;
+import org.apache.tez.runtime.api.events.InputInitializerEvent;
+import org.junit.Test;
+
+public class TestDynamicPartitionPruner {
+
+  @Test(timeout = 5000)
+  public void testNoPruning() throws InterruptedException, IOException, HiveException,
+      SerDeException {
+    InputInitializerContext mockInitContext = mock(InputInitializerContext.class);
+    MapWork mapWork = mock(MapWork.class);
+    DynamicPartitionPruner pruner =
+        new DynamicPartitionPrunerForEventTesting(mockInitContext, mapWork);
+
+    PruneRunnable pruneRunnable = new PruneRunnable(pruner);
+    Thread t = new Thread(pruneRunnable);
+    t.start();
+    try {
+      pruneRunnable.start();
+      pruneRunnable.awaitEnd();
+      // Return immediately. No entries found for pruning. Verified via the timeout.
+    } finally {
+      t.interrupt();
+      t.join();
+    }
+  }
+
+  @Test(timeout = 5000)
+  public void testSingleSourceOrdering1() throws InterruptedException, IOException, HiveException,
+      SerDeException {
+    InputInitializerContext mockInitContext = mock(InputInitializerContext.class);
+    doReturn(1).when(mockInitContext).getVertexNumTasks("v1");
+
+    MapWork mapWork = createMockMapWork(new TestSource("v1", 1));
+    DynamicPartitionPruner pruner =
+        new DynamicPartitionPrunerForEventTesting(mockInitContext, mapWork);
+
+
+    PruneRunnable pruneRunnable = new PruneRunnable(pruner);
+    Thread t = new Thread(pruneRunnable);
+    t.start();
+    try {
+      pruneRunnable.start();
+
+      InputInitializerEvent event =
+          InputInitializerEvent.create("FakeTarget", "TargetInput", ByteBuffer.allocate(0));
+      event.setSourceVertexName("v1");
+
+      pruner.addEvent(event);
+      pruner.processVertex("v1");
+
+      pruneRunnable.awaitEnd();
+      assertFalse(pruneRunnable.inError.get());
+    } finally {
+      t.interrupt();
+      t.join();
+    }
+  }
+
+  @Test(timeout = 5000)
+  public void testSingleSourceOrdering2() throws InterruptedException, IOException, HiveException,
+      SerDeException {
+    InputInitializerContext mockInitContext = mock(InputInitializerContext.class);
+    doReturn(1).when(mockInitContext).getVertexNumTasks("v1");
+
+    MapWork mapWork = createMockMapWork(new TestSource("v1", 1));
+    DynamicPartitionPruner pruner =
+        new DynamicPartitionPrunerForEventTesting(mockInitContext, mapWork);
+
+
+    PruneRunnable pruneRunnable = new PruneRunnable(pruner);
+    Thread t = new Thread(pruneRunnable);
+    t.start();
+    try {
+      pruneRunnable.start();
+
+      InputInitializerEvent event =
+          InputInitializerEvent.create("FakeTarget", "TargetInput", ByteBuffer.allocate(0));
+      event.setSourceVertexName("v1");
+
+      pruner.processVertex("v1");
+      pruner.addEvent(event);
+
+      pruneRunnable.awaitEnd();
+      assertFalse(pruneRunnable.inError.get());
+    } finally {
+      t.interrupt();
+      t.join();
+    }
+  }
+
+  @Test(timeout = 5000)
+  public void testSingleSourceMultipleFiltersOrdering1() throws InterruptedException, SerDeException {
+    InputInitializerContext mockInitContext = mock(InputInitializerContext.class);
+    doReturn(2).when(mockInitContext).getVertexNumTasks("v1");
+
+    MapWork mapWork = createMockMapWork(new TestSource("v1", 2));
+    DynamicPartitionPruner pruner =
+        new DynamicPartitionPrunerForEventTesting(mockInitContext, mapWork);
+
+    PruneRunnable pruneRunnable = new PruneRunnable(pruner);
+    Thread t = new Thread(pruneRunnable);
+    t.start();
+    try {
+      pruneRunnable.start();
+
+      InputInitializerEvent event =
+          InputInitializerEvent.create("FakeTarget", "TargetInput", ByteBuffer.allocate(0));
+      event.setSourceVertexName("v1");
+
+      pruner.addEvent(event);
+      pruner.addEvent(event);
+      pruner.addEvent(event);
+      pruner.addEvent(event);
+      pruner.processVertex("v1");
+
+      pruneRunnable.awaitEnd();
+      assertFalse(pruneRunnable.inError.get());
+    } finally {
+      t.interrupt();
+      t.join();
+    }
+  }
+
+  @Test(timeout = 5000)
+  public void testSingleSourceMultipleFiltersOrdering2() throws InterruptedException, SerDeException {
+    InputInitializerContext mockInitContext = mock(InputInitializerContext.class);
+    doReturn(2).when(mockInitContext).getVertexNumTasks("v1");
+
+    MapWork mapWork = createMockMapWork(new TestSource("v1", 2));
+    DynamicPartitionPruner pruner =
+        new DynamicPartitionPrunerForEventTesting(mockInitContext, mapWork);
+
+    PruneRunnable pruneRunnable = new PruneRunnable(pruner);
+    Thread t = new Thread(pruneRunnable);
+    t.start();
+    try {
+      pruneRunnable.start();
+
+      InputInitializerEvent event =
+          InputInitializerEvent.create("FakeTarget", "TargetInput", ByteBuffer.allocate(0));
+      event.setSourceVertexName("v1");
+
+      pruner.processVertex("v1");
+      pruner.addEvent(event);
+      pruner.addEvent(event);
+      pruner.addEvent(event);
+      pruner.addEvent(event);
+
+      pruneRunnable.awaitEnd();
+      assertFalse(pruneRunnable.inError.get());
+    } finally {
+      t.interrupt();
+      t.join();
+    }
+  }
+
+  @Test(timeout = 5000)
+  public void testMultipleSourcesOrdering1() throws InterruptedException, SerDeException {
+    InputInitializerContext mockInitContext = mock(InputInitializerContext.class);
+    doReturn(2).when(mockInitContext).getVertexNumTasks("v1");
+    doReturn(3).when(mockInitContext).getVertexNumTasks("v2");
+
+    MapWork mapWork = createMockMapWork(new TestSource("v1", 2), new TestSource("v2", 1));
+    DynamicPartitionPruner pruner =
+        new DynamicPartitionPrunerForEventTesting(mockInitContext, mapWork);
+
+    PruneRunnable pruneRunnable = new PruneRunnable(pruner);
+    Thread t = new Thread(pruneRunnable);
+    t.start();
+    try {
+      pruneRunnable.start();
+
+      InputInitializerEvent eventV1 =
+          InputInitializerEvent.create("FakeTarget", "TargetInput", ByteBuffer.allocate(0));
+      eventV1.setSourceVertexName("v1");
+
+      InputInitializerEvent eventV2 =
+          InputInitializerEvent.create("FakeTarget", "TargetInput", ByteBuffer.allocate(0));
+      eventV2.setSourceVertexName("v2");
+
+      // 2 X 2 events for V1. 3 X 1 events for V2
+
+      pruner.addEvent(eventV1);
+      pruner.addEvent(eventV1);
+      pruner.addEvent(eventV1);
+      pruner.addEvent(eventV1);
+      pruner.addEvent(eventV2);
+      pruner.addEvent(eventV2);
+      pruner.addEvent(eventV2);
+      pruner.processVertex("v1");
+      pruner.processVertex("v2");
+
+      pruneRunnable.awaitEnd();
+      assertFalse(pruneRunnable.inError.get());
+    } finally {
+      t.interrupt();
+      t.join();
+    }
+  }
+
+  @Test(timeout = 5000)
+  public void testMultipleSourcesOrdering2() throws InterruptedException, SerDeException {
+    InputInitializerContext mockInitContext = mock(InputInitializerContext.class);
+    doReturn(2).when(mockInitContext).getVertexNumTasks("v1");
+    doReturn(3).when(mockInitContext).getVertexNumTasks("v2");
+
+    MapWork mapWork = createMockMapWork(new TestSource("v1", 2), new TestSource("v2", 1));
+    DynamicPartitionPruner pruner =
+        new DynamicPartitionPrunerForEventTesting(mockInitContext, mapWork);
+
+    PruneRunnable pruneRunnable = new PruneRunnable(pruner);
+    Thread t = new Thread(pruneRunnable);
+    t.start();
+    try {
+      pruneRunnable.start();
+
+      InputInitializerEvent eventV1 =
+          InputInitializerEvent.create("FakeTarget", "TargetInput", ByteBuffer.allocate(0));
+      eventV1.setSourceVertexName("v1");
+
+      InputInitializerEvent eventV2 =
+          InputInitializerEvent.create("FakeTarget", "TargetInput", ByteBuffer.allocate(0));
+      eventV2.setSourceVertexName("v2");
+
+      // 2 X 2 events for V1. 3 X 1 events for V2
+
+      pruner.processVertex("v1");
+      pruner.processVertex("v2");
+      pruner.addEvent(eventV1);
+      pruner.addEvent(eventV1);
+      pruner.addEvent(eventV1);
+      pruner.addEvent(eventV1);
+      pruner.addEvent(eventV2);
+      pruner.addEvent(eventV2);
+      pruner.addEvent(eventV2);
+
+      pruneRunnable.awaitEnd();
+      assertFalse(pruneRunnable.inError.get());
+    } finally {
+      t.interrupt();
+      t.join();
+    }
+  }
+
+  @Test(timeout = 5000)
+  public void testMultipleSourcesOrdering3() throws InterruptedException, SerDeException {
+    InputInitializerContext mockInitContext = mock(InputInitializerContext.class);
+    doReturn(2).when(mockInitContext).getVertexNumTasks("v1");
+    doReturn(3).when(mockInitContext).getVertexNumTasks("v2");
+
+    MapWork mapWork = createMockMapWork(new TestSource("v1", 2), new TestSource("v2", 1));
+    DynamicPartitionPruner pruner =
+        new DynamicPartitionPrunerForEventTesting(mockInitContext, mapWork);
+
+    PruneRunnable pruneRunnable = new PruneRunnable(pruner);
+    Thread t = new Thread(pruneRunnable);
+    t.start();
+    try {
+      pruneRunnable.start();
+
+      InputInitializerEvent eventV1 =
+          InputInitializerEvent.create("FakeTarget", "TargetInput", ByteBuffer.allocate(0));
+      eventV1.setSourceVertexName("v1");
+
+      InputInitializerEvent eventV2 =
+          InputInitializerEvent.create("FakeTarget", "TargetInput", ByteBuffer.allocate(0));
+      eventV2.setSourceVertexName("v2");
+
+      // 2 X 2 events for V1. 3 X 1 events for V2
+      pruner.addEvent(eventV1);
+      pruner.addEvent(eventV1);
+      pruner.processVertex("v1");
+      pruner.addEvent(eventV1);
+      pruner.addEvent(eventV1);
+      pruner.addEvent(eventV2);
+      pruner.processVertex("v2");
+      pruner.addEvent(eventV2);
+      pruner.addEvent(eventV2);
+
+      pruneRunnable.awaitEnd();
+      assertFalse(pruneRunnable.inError.get());
+    } finally {
+      t.interrupt();
+      t.join();
+    }
+  }
+
+  @Test(timeout = 5000, expected = IllegalStateException.class)
+  public void testExtraEvents() throws InterruptedException, IOException, HiveException,
+      SerDeException {
+    InputInitializerContext mockInitContext = mock(InputInitializerContext.class);
+    doReturn(1).when(mockInitContext).getVertexNumTasks("v1");
+
+    MapWork mapWork = createMockMapWork(new TestSource("v1", 1));
+    DynamicPartitionPruner pruner =
+        new DynamicPartitionPrunerForEventTesting(mockInitContext, mapWork);
+
+
+    PruneRunnable pruneRunnable = new PruneRunnable(pruner);
+    Thread t = new Thread(pruneRunnable);
+    t.start();
+    try {
+      pruneRunnable.start();
+
+      InputInitializerEvent event =
+          InputInitializerEvent.create("FakeTarget", "TargetInput", ByteBuffer.allocate(0));
+      event.setSourceVertexName("v1");
+
+      pruner.addEvent(event);
+      pruner.addEvent(event);
+      pruner.processVertex("v1");
+
+      pruneRunnable.awaitEnd();
+      assertFalse(pruneRunnable.inError.get());
+    } finally {
+      t.interrupt();
+      t.join();
+    }
+  }
+
+  @Test(timeout = 20000)
+  public void testMissingEvent() throws InterruptedException, IOException, HiveException,
+      SerDeException {
+    InputInitializerContext mockInitContext = mock(InputInitializerContext.class);
+    doReturn(1).when(mockInitContext).getVertexNumTasks("v1");
+
+    MapWork mapWork = createMockMapWork(new TestSource("v1", 1));
+    DynamicPartitionPruner pruner =
+        new DynamicPartitionPrunerForEventTesting(mockInitContext, mapWork);
+
+
+    PruneRunnable pruneRunnable = new PruneRunnable(pruner);
+    Thread t = new Thread(pruneRunnable);
+    t.start();
+    try {
+      pruneRunnable.start();
+
+      InputInitializerEvent event =
+          InputInitializerEvent.create("FakeTarget", "TargetInput", ByteBuffer.allocate(0));
+      event.setSourceVertexName("v1");
+
+      pruner.processVertex("v1");
+      Thread.sleep(3000l);
+      // The pruner should not have completed.
+      assertFalse(pruneRunnable.ended.get());
+      assertFalse(pruneRunnable.inError.get());
+    } finally {
+      t.interrupt();
+      t.join();
+    }
+  }
+
+  private static class PruneRunnable implements Runnable {
+
+    final DynamicPartitionPruner pruner;
+    final ReentrantLock lock = new ReentrantLock();
+    final Condition endCondition = lock.newCondition();
+    final Condition startCondition = lock.newCondition();
+    final AtomicBoolean started = new AtomicBoolean(false);
+    final AtomicBoolean ended = new AtomicBoolean(false);
+    final AtomicBoolean inError = new AtomicBoolean(false);
+
+    private PruneRunnable(DynamicPartitionPruner pruner) {
+      this.pruner = pruner;
+    }
+
+    void start() {
+      started.set(true);
+      lock.lock();
+      try {
+        startCondition.signal();
+      } finally {
+        lock.unlock();
+      }
+    }
+
+    void awaitEnd() throws InterruptedException {
+      lock.lock();
+      try {
+        while (!ended.get()) {
+          endCondition.await();
+        }
+      } finally {
+        lock.unlock();
+      }
+    }
+
+    @Override
+    public void run() {
+      try {
+        lock.lock();
+        try {
+          while (!started.get()) {
+            startCondition.await();
+          }
+        } finally {
+          lock.unlock();
+        }
+
+        pruner.prune();
+        lock.lock();
+        try {
+          ended.set(true);
+          endCondition.signal();
+        } finally {
+          lock.unlock();
+        }
+      } catch (SerDeException | IOException | InterruptedException | HiveException e) {
+        inError.set(true);
+      }
+    }
+  }
+
+
+  private MapWork createMockMapWork(TestSource... testSources) {
+    MapWork mapWork = mock(MapWork.class);
+
+    Map<String, List<TableDesc>> tableMap = new HashMap<>();
+    Map<String, List<String>> columnMap = new HashMap<>();
+    Map<String, List<ExprNodeDesc>> exprMap = new HashMap<>();
+
+    int count = 0;
+    for (TestSource testSource : testSources) {
+
+      for (int i = 0; i < testSource.numExpressions; i++) {
+        List<TableDesc> tableDescList = tableMap.get(testSource.vertexName);
+        if (tableDescList == null) {
+          tableDescList = new LinkedList<>();
+          tableMap.put(testSource.vertexName, tableDescList);
+        }
+        tableDescList.add(mock(TableDesc.class));
+
+        List<String> columnList = columnMap.get(testSource.vertexName);
+        if (columnList == null) {
+          columnList = new LinkedList<>();
+          columnMap.put(testSource.vertexName, columnList);
+        }
+        columnList.add(testSource.vertexName + "c_" + count + "_" + i);
+
+        List<ExprNodeDesc> exprNodeDescList = exprMap.get(testSource.vertexName);
+        if (exprNodeDescList == null) {
+          exprNodeDescList = new LinkedList<>();
+          exprMap.put(testSource.vertexName, exprNodeDescList);
+        }
+        exprNodeDescList.add(mock(ExprNodeDesc.class));
+      }
+
+      count++;
+    }
+
+    doReturn(tableMap).when(mapWork).getEventSourceTableDescMap();
+    doReturn(columnMap).when(mapWork).getEventSourceColumnNameMap();
+    doReturn(exprMap).when(mapWork).getEventSourcePartKeyExprMap();
+    return mapWork;
+  }
+
+  private static class TestSource {
+    String vertexName;
+    int numExpressions;
+
+    public TestSource(String vertexName, int numExpressions) {
+      this.vertexName = vertexName;
+      this.numExpressions = numExpressions;
+    }
+  }
+
+  private static class DynamicPartitionPrunerForEventTesting extends DynamicPartitionPruner {
+
+
+    public DynamicPartitionPrunerForEventTesting(
+        InputInitializerContext context, MapWork work) throws SerDeException {
+      super(context, work, new JobConf());
+    }
+
+    @Override
+    protected SourceInfo createSourceInfo(TableDesc t, ExprNodeDesc partKeyExpr, String columnName,
+                                          JobConf jobConf) throws
+        SerDeException {
+      return new SourceInfo(t, partKeyExpr, columnName, jobConf, null);
+    }
+
+    @Override
+    protected String processPayload(ByteBuffer payload, String sourceName) throws SerDeException,
+        IOException {
+      // No-op: testing events only
+      return sourceName;
+    }
+
+    @Override
+    protected void prunePartitionSingleSource(String source, SourceInfo si)
+        throws HiveException {
+      // No-op: testing events only
+    }
+  }
+}


