beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From k...@apache.org
Subject [1/2] incubator-beam git commit: Add control of PipelineVisitor recursion into composite transforms
Date Mon, 09 May 2016 14:52:31 GMT
Repository: incubator-beam
Updated Branches:
  refs/heads/master 07c60a965 -> 03e99540a


Add control of PipelineVisitor recursion into composite transforms


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/dbf7a06a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/dbf7a06a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/dbf7a06a

Branch: refs/heads/master
Commit: dbf7a06a7c901cda065e49914932cf0be5d6db4e
Parents: 07c60a9
Author: Kenneth Knowles <klk@google.com>
Authored: Wed Apr 20 12:28:50 2016 -0700
Committer: Kenneth Knowles <klk@google.com>
Committed: Mon May 9 07:29:49 2016 -0700

----------------------------------------------------------------------
 .../direct/ConsumerTrackingPipelineVisitor.java |  7 +-
 .../direct/KeyedPValueTrackingVisitor.java      |  5 +-
 .../FlinkBatchPipelineTranslator.java           | 83 +++++++++-----------
 .../translation/FlinkPipelineTranslator.java    |  2 +-
 .../FlinkStreamingPipelineTranslator.java       | 35 ++-------
 .../dataflow/DataflowPipelineRunner.java        |  5 +-
 .../dataflow/DataflowPipelineTranslator.java    |  7 +-
 .../dataflow/DataflowPipelineRunnerTest.java    | 18 +----
 .../beam/runners/spark/SparkPipelineRunner.java | 48 ++---------
 .../main/java/org/apache/beam/sdk/Pipeline.java | 35 ++++++++-
 .../runners/AggregatorPipelineExtractor.java    | 10 +--
 .../beam/sdk/runners/DirectPipelineRunner.java  | 12 +--
 .../sdk/runners/RecordingPipelineVisitor.java   | 12 +--
 .../beam/sdk/runners/TransformTreeNode.java     | 16 ++--
 .../AggregatorPipelineExtractorTest.java        |  2 +-
 .../beam/sdk/runners/TransformTreeTest.java     | 12 +--
 16 files changed, 124 insertions(+), 185 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/dbf7a06a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ConsumerTrackingPipelineVisitor.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ConsumerTrackingPipelineVisitor.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ConsumerTrackingPipelineVisitor.java
index c790463..3300723 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ConsumerTrackingPipelineVisitor.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ConsumerTrackingPipelineVisitor.java
@@ -41,7 +41,7 @@ import java.util.Set;
  * {@link Pipeline}. This is used to schedule consuming {@link PTransform PTransforms} to
consume
  * input after the upstream transform has produced and committed output.
  */
-public class ConsumerTrackingPipelineVisitor implements PipelineVisitor {
+public class ConsumerTrackingPipelineVisitor extends PipelineVisitor.Defaults {
   private Map<PValue, Collection<AppliedPTransform<?, ?, ?>>> valueToConsumers
= new HashMap<>();
   private Collection<AppliedPTransform<?, ?, ?>> rootTransforms = new ArrayList<>();
   private Collection<PCollectionView<?>> views = new ArrayList<>();
@@ -51,13 +51,14 @@ public class ConsumerTrackingPipelineVisitor implements PipelineVisitor {
   private boolean finalized = false;
 
   @Override
-  public void enterCompositeTransform(TransformTreeNode node) {
+  public CompositeBehavior enterCompositeTransform(TransformTreeNode node) {
     checkState(
         !finalized,
         "Attempting to traverse a pipeline (node %s) with a %s "
             + "which has already visited a Pipeline and is finalized",
         node.getFullName(),
         ConsumerTrackingPipelineVisitor.class.getSimpleName());
+    return CompositeBehavior.ENTER_TRANSFORM;
   }
 
   @Override
@@ -73,7 +74,7 @@ public class ConsumerTrackingPipelineVisitor implements PipelineVisitor {
   }
 
   @Override
-  public void visitTransform(TransformTreeNode node) {
+  public void visitPrimitiveTransform(TransformTreeNode node) {
     toFinalize.removeAll(node.getInput().expand());
     AppliedPTransform<?, ?, ?> appliedTransform = getAppliedTransform(node);
     stepNames.put(appliedTransform, genStepName());

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/dbf7a06a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java
index b7c755e..2fea00a 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java
@@ -56,12 +56,13 @@ class KeyedPValueTrackingVisitor implements PipelineVisitor {
   }
 
   @Override
-  public void enterCompositeTransform(TransformTreeNode node) {
+  public CompositeBehavior enterCompositeTransform(TransformTreeNode node) {
     checkState(
         !finalized,
         "Attempted to use a %s that has already been finalized on a pipeline (visiting node
%s)",
         KeyedPValueTrackingVisitor.class.getSimpleName(),
         node);
+    return CompositeBehavior.ENTER_TRANSFORM;
   }
 
   @Override
@@ -79,7 +80,7 @@ class KeyedPValueTrackingVisitor implements PipelineVisitor {
   }
 
   @Override
-  public void visitTransform(TransformTreeNode node) {}
+  public void visitPrimitiveTransform(TransformTreeNode node) {}
 
   @Override
   public void visitValue(PValue value, TransformTreeNode producer) {

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/dbf7a06a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchPipelineTranslator.java
----------------------------------------------------------------------
diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchPipelineTranslator.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchPipelineTranslator.java
index 456cf09..3d39e81 100644
--- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchPipelineTranslator.java
+++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchPipelineTranslator.java
@@ -43,11 +43,6 @@ public class FlinkBatchPipelineTranslator extends FlinkPipelineTranslator {
 
   private int depth = 0;
 
-  /**
-   * Composite transform that we want to translate before proceeding with other transforms.
-   */
-  private PTransform<?, ?> currentCompositeTransform;
-
   public FlinkBatchPipelineTranslator(ExecutionEnvironment env, PipelineOptions options)
{
     this.batchContext = new FlinkBatchTranslationContext(env, options);
   }
@@ -57,54 +52,33 @@ public class FlinkBatchPipelineTranslator extends FlinkPipelineTranslator {
   // --------------------------------------------------------------------------------------------
 
   @Override
-  public void enterCompositeTransform(TransformTreeNode node) {
+  public CompositeBehavior enterCompositeTransform(TransformTreeNode node) {
     LOG.info(genSpaces(this.depth) + "enterCompositeTransform- " + formatNodeName(node));
 
-    PTransform<?, ?> transform = node.getTransform();
-    if (transform != null && currentCompositeTransform == null) {
-
-      BatchTransformTranslator<?> translator = FlinkBatchTransformTranslators.getTranslator(transform);
-      if (translator != null) {
-        currentCompositeTransform = transform;
-        if (transform instanceof CoGroupByKey && node.getInput().expand().size()
!= 2) {
-          // we can only optimize CoGroupByKey for input size 2
-          currentCompositeTransform = null;
-        }
-      }
+    BatchTransformTranslator<?> translator = getTranslator(node);
+
+    if (translator != null) {
+      applyBatchTransform(node.getTransform(), node, translator);
+      LOG.info(genSpaces(this.depth) + "translated-" + formatNodeName(node));
+      return CompositeBehavior.DO_NOT_ENTER_TRANSFORM;
+    } else {
+      this.depth++;
+      return CompositeBehavior.ENTER_TRANSFORM;
     }
-    this.depth++;
   }
 
   @Override
   public void leaveCompositeTransform(TransformTreeNode node) {
-    PTransform<?, ?> transform = node.getTransform();
-    if (transform != null && currentCompositeTransform == transform) {
-
-      BatchTransformTranslator<?> translator = FlinkBatchTransformTranslators.getTranslator(transform);
-      if (translator != null) {
-        LOG.info(genSpaces(this.depth) + "doingCompositeTransform- " + formatNodeName(node));
-        applyBatchTransform(transform, node, translator);
-        currentCompositeTransform = null;
-      } else {
-        throw new IllegalStateException("Attempted to translate composite transform " +
-            "but no translator was found: " + currentCompositeTransform);
-      }
-    }
     this.depth--;
     LOG.info(genSpaces(this.depth) + "leaveCompositeTransform- " + formatNodeName(node));
   }
 
   @Override
-  public void visitTransform(TransformTreeNode node) {
-    LOG.info(genSpaces(this.depth) + "visitTransform- " + formatNodeName(node));
-    if (currentCompositeTransform != null) {
-      // ignore it
-      return;
-    }
+  public void visitPrimitiveTransform(TransformTreeNode node) {
+    LOG.info(genSpaces(this.depth) + "visitPrimitiveTransform- " + formatNodeName(node));
 
-    // get the transformation corresponding to hte node we are
+    // get the transformation corresponding to the node we are
     // currently visiting and translate it into its Flink alternative.
-
     PTransform<?, ?> transform = node.getTransform();
     BatchTransformTranslator<?> translator = FlinkBatchTransformTranslators.getTranslator(transform);
     if (translator == null) {
@@ -114,11 +88,6 @@ public class FlinkBatchPipelineTranslator extends FlinkPipelineTranslator {
     applyBatchTransform(transform, node, translator);
   }
 
-  @Override
-  public void visitValue(PValue value, TransformTreeNode producer) {
-    // do nothing here
-  }
-
   private <T extends PTransform<?, ?>> void applyBatchTransform(PTransform<?,
?> transform, TransformTreeNode node, BatchTransformTranslator<?> translator) {
 
     @SuppressWarnings("unchecked")
@@ -140,6 +109,32 @@ public class FlinkBatchPipelineTranslator extends FlinkPipelineTranslator {
     void translateNode(Type transform, FlinkBatchTranslationContext context);
   }
 
+  /**
+   * Returns a translator for the given node, if it is possible, otherwise null.
+   */
+  private static BatchTransformTranslator<?> getTranslator(TransformTreeNode node)
{
+    PTransform<?, ?> transform = node.getTransform();
+
+    // Root of the graph is null
+    if (transform == null) {
+      return null;
+    }
+
+    BatchTransformTranslator<?> translator = FlinkBatchTransformTranslators.getTranslator(transform);
+
+    // No translator known
+    if (translator == null) {
+      return null;
+    }
+
+    // We actually only specialize CoGroupByKey when exactly 2 inputs
+    if (transform instanceof CoGroupByKey && node.getInput().expand().size() != 2)
{
+      return null;
+    }
+
+    return translator;
+  }
+
   private static String genSpaces(int n) {
     String s = "";
     for (int i = 0; i < n; i++) {

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/dbf7a06a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkPipelineTranslator.java
----------------------------------------------------------------------
diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkPipelineTranslator.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkPipelineTranslator.java
index 82d23b0..46e5712 100644
--- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkPipelineTranslator.java
+++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkPipelineTranslator.java
@@ -28,7 +28,7 @@ import org.apache.beam.sdk.Pipeline;
  * a {@link org.apache.flink.streaming.api.datastream.DataStream} (for streaming) or a
  * {@link org.apache.flink.api.java.DataSet} (for batch) one.
  */
-public abstract class FlinkPipelineTranslator implements Pipeline.PipelineVisitor {
+public abstract class FlinkPipelineTranslator extends Pipeline.PipelineVisitor.Defaults {
 
   public void translate(Pipeline pipeline) {
     pipeline.traverseTopologically(this);

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/dbf7a06a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkStreamingPipelineTranslator.java
----------------------------------------------------------------------
diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkStreamingPipelineTranslator.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkStreamingPipelineTranslator.java
index ebaf6ba..31b2bee 100644
--- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkStreamingPipelineTranslator.java
+++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkStreamingPipelineTranslator.java
@@ -43,9 +43,6 @@ public class FlinkStreamingPipelineTranslator extends FlinkPipelineTranslator {
 
   private int depth = 0;
 
-  /** Composite transform that we want to translate before proceeding with other transforms.
*/
-  private PTransform<?, ?> currentCompositeTransform;
-
   public FlinkStreamingPipelineTranslator(StreamExecutionEnvironment env, PipelineOptions
options) {
     this.streamingContext = new FlinkStreamingTranslationContext(env, options);
   }
@@ -55,47 +52,31 @@ public class FlinkStreamingPipelineTranslator extends FlinkPipelineTranslator {
   // --------------------------------------------------------------------------------------------
 
   @Override
-  public void enterCompositeTransform(TransformTreeNode node) {
+  public CompositeBehavior enterCompositeTransform(TransformTreeNode node) {
     LOG.info(genSpaces(this.depth) + "enterCompositeTransform- " + formatNodeName(node));
 
     PTransform<?, ?> transform = node.getTransform();
-    if (transform != null && currentCompositeTransform == null) {
-
+    if (transform != null) {
       StreamTransformTranslator<?> translator = FlinkStreamingTransformTranslators.getTranslator(transform);
       if (translator != null) {
-        currentCompositeTransform = transform;
+        applyStreamingTransform(transform, node, translator);
+        LOG.info(genSpaces(this.depth) + "translated-" + formatNodeName(node));
+        return CompositeBehavior.DO_NOT_ENTER_TRANSFORM;
       }
     }
     this.depth++;
+    return CompositeBehavior.ENTER_TRANSFORM;
   }
 
   @Override
   public void leaveCompositeTransform(TransformTreeNode node) {
-    PTransform<?, ?> transform = node.getTransform();
-    if (transform != null && currentCompositeTransform == transform) {
-
-      StreamTransformTranslator<?> translator = FlinkStreamingTransformTranslators.getTranslator(transform);
-      if (translator != null) {
-        LOG.info(genSpaces(this.depth) + "doingCompositeTransform- " + formatNodeName(node));
-        applyStreamingTransform(transform, node, translator);
-        currentCompositeTransform = null;
-      } else {
-        throw new IllegalStateException("Attempted to translate composite transform " +
-            "but no translator was found: " + currentCompositeTransform);
-      }
-    }
     this.depth--;
     LOG.info(genSpaces(this.depth) + "leaveCompositeTransform- " + formatNodeName(node));
   }
 
   @Override
-  public void visitTransform(TransformTreeNode node) {
-    LOG.info(genSpaces(this.depth) + "visitTransform- " + formatNodeName(node));
-    if (currentCompositeTransform != null) {
-      // ignore it
-      return;
-    }
-
+  public void visitPrimitiveTransform(TransformTreeNode node) {
+    LOG.info(genSpaces(this.depth) + "visitPrimitiveTransform- " + formatNodeName(node));
     // get the transformation corresponding to hte node we are
     // currently visiting and translate it into its Flink alternative.
 

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/dbf7a06a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineRunner.java
----------------------------------------------------------------------
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineRunner.java
index 41b4df7..4076802 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineRunner.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineRunner.java
@@ -680,17 +680,18 @@ public class DataflowPipelineRunner extends PipelineRunner<DataflowPipelineJob>
         }
 
         @Override
-        public void visitTransform(TransformTreeNode node) {
+        public void visitPrimitiveTransform(TransformTreeNode node) {
           if (ptransformViewsWithNonDeterministicKeyCoders.contains(node.getTransform()))
{
             ptransformViewNamesWithNonDeterministicKeyCoders.add(node.getFullName());
           }
         }
 
         @Override
-        public void enterCompositeTransform(TransformTreeNode node) {
+        public CompositeBehavior enterCompositeTransform(TransformTreeNode node) {
           if (ptransformViewsWithNonDeterministicKeyCoders.contains(node.getTransform()))
{
             ptransformViewNamesWithNonDeterministicKeyCoders.add(node.getFullName());
           }
+          return CompositeBehavior.ENTER_TRANSFORM;
         }
 
         @Override

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/dbf7a06a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java
----------------------------------------------------------------------
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java
index 4ef1bdb..05879d9 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java
@@ -348,7 +348,7 @@ public class DataflowPipelineTranslator {
   /**
    * Translates a Pipeline into the Dataflow representation.
    */
-  class Translator implements PipelineVisitor, TranslationContext {
+  class Translator extends PipelineVisitor.Defaults implements TranslationContext {
     /** The Pipeline to translate. */
     private final Pipeline pipeline;
 
@@ -493,16 +493,13 @@ public class DataflowPipelineTranslator {
       return currentTransform;
     }
 
-    @Override
-    public void enterCompositeTransform(TransformTreeNode node) {
-    }
 
     @Override
     public void leaveCompositeTransform(TransformTreeNode node) {
     }
 
     @Override
-    public void visitTransform(TransformTreeNode node) {
+    public void visitPrimitiveTransform(TransformTreeNode node) {
       PTransform<?, ?> transform = node.getTransform();
       TransformTranslator translator =
           getTransformTranslator(transform.getClass());

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/dbf7a06a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineRunnerTest.java
----------------------------------------------------------------------
diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineRunnerTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineRunnerTest.java
index d4d4b3b..2993c50 100644
--- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineRunnerTest.java
+++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineRunnerTest.java
@@ -84,7 +84,6 @@ import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PDone;
 import org.apache.beam.sdk.values.PInput;
-import org.apache.beam.sdk.values.PValue;
 import org.apache.beam.sdk.values.TimestampedValue;
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.sdk.values.TupleTagList;
@@ -820,26 +819,15 @@ public class DataflowPipelineRunnerTest {
   }
 
   /** Records all the composite transforms visited within the Pipeline. */
-  private static class CompositeTransformRecorder implements PipelineVisitor {
+  private static class CompositeTransformRecorder extends PipelineVisitor.Defaults {
     private List<PTransform<?, ?>> transforms = new ArrayList<>();
 
     @Override
-    public void enterCompositeTransform(TransformTreeNode node) {
+    public CompositeBehavior enterCompositeTransform(TransformTreeNode node) {
       if (node.getTransform() != null) {
         transforms.add(node.getTransform());
       }
-    }
-
-    @Override
-    public void leaveCompositeTransform(TransformTreeNode node) {
-    }
-
-    @Override
-    public void visitTransform(TransformTreeNode node) {
-    }
-
-    @Override
-    public void visitValue(PValue value, TransformTreeNode producer) {
+      return CompositeBehavior.ENTER_TRANSFORM;
     }
 
     public List<PTransform<?, ?>> getCompositeTransforms() {

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/dbf7a06a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineRunner.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineRunner.java
index bae4e53..af5acf1 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineRunner.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineRunner.java
@@ -41,7 +41,6 @@ import org.apache.beam.sdk.util.GroupByKeyViaGroupByKeyOnly;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PInput;
 import org.apache.beam.sdk.values.POutput;
-import org.apache.beam.sdk.values.PValue;
 
 import org.apache.spark.SparkException;
 import org.apache.spark.api.java.JavaSparkContext;
@@ -219,7 +218,7 @@ public final class SparkPipelineRunner extends PipelineRunner<EvaluationResult>
   /**
    * Evaluator on the pipeline.
    */
-  public abstract static class Evaluator implements Pipeline.PipelineVisitor {
+  public abstract static class Evaluator extends Pipeline.PipelineVisitor.Defaults {
     protected static final Logger LOG = LoggerFactory.getLogger(Evaluator.class);
 
     protected final SparkPipelineTranslator translator;
@@ -228,62 +227,29 @@ public final class SparkPipelineRunner extends PipelineRunner<EvaluationResult>
       this.translator = translator;
     }
 
-    // Set upon entering a composite node which can be directly mapped to a single
-    // TransformEvaluator.
-    private TransformTreeNode currentTranslatedCompositeNode;
-
-    /**
-     * If true, we're currently inside a subtree of a composite node which directly maps
to a
-     * single
-     * TransformEvaluator; children nodes are ignored, and upon post-visiting the translated
-     * composite node, the associated TransformEvaluator will be visited.
-     */
-    private boolean inTranslatedCompositeNode() {
-      return currentTranslatedCompositeNode != null;
-    }
-
     @Override
-    public void enterCompositeTransform(TransformTreeNode node) {
-      if (!inTranslatedCompositeNode() && node.getTransform() != null) {
+    public CompositeBehavior enterCompositeTransform(TransformTreeNode node) {
+      if (node.getTransform() != null) {
         @SuppressWarnings("unchecked")
         Class<PTransform<?, ?>> transformClass =
             (Class<PTransform<?, ?>>) node.getTransform().getClass();
         if (translator.hasTranslation(transformClass)) {
           LOG.info("Entering directly-translatable composite transform: '{}'", node.getFullName());
           LOG.debug("Composite transform class: '{}'", transformClass);
-          currentTranslatedCompositeNode = node;
+          doVisitTransform(node);
+          return CompositeBehavior.DO_NOT_ENTER_TRANSFORM;
         }
       }
+      return CompositeBehavior.ENTER_TRANSFORM;
     }
 
     @Override
-    public void leaveCompositeTransform(TransformTreeNode node) {
-      // NB: We depend on enterCompositeTransform and leaveCompositeTransform providing 'node'
-      // objects for which Object.equals() returns true iff they are the same logical node
-      // within the tree.
-      if (inTranslatedCompositeNode() && node.equals(currentTranslatedCompositeNode))
{
-        LOG.info("Post-visiting directly-translatable composite transform: '{}'",
-                node.getFullName());
-        doVisitTransform(node);
-        currentTranslatedCompositeNode = null;
-      }
-    }
-
-    @Override
-    public void visitTransform(TransformTreeNode node) {
-      if (inTranslatedCompositeNode()) {
-        LOG.info("Skipping '{}'; already in composite transform.", node.getFullName());
-        return;
-      }
+    public void visitPrimitiveTransform(TransformTreeNode node) {
       doVisitTransform(node);
     }
 
     protected abstract <TransformT extends PTransform<? super PInput, POutput>>
void
         doVisitTransform(TransformTreeNode node);
-
-    @Override
-    public void visitValue(PValue value, TransformTreeNode producer) {
-    }
   }
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/dbf7a06a/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java
index 65a0755..4e7e63f 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java
@@ -220,8 +220,10 @@ public class Pipeline {
     /**
      * Called for each composite transform after all topological predecessors have been visited
      * but before any of its component transforms.
+     *
+     * <p>The return value controls whether or not child transforms are visited.
      */
-    public void enterCompositeTransform(TransformTreeNode node);
+    public CompositeBehavior enterCompositeTransform(TransformTreeNode node);
 
     /**
      * Called for each composite transform after all of its component transforms and their
outputs
@@ -233,13 +235,42 @@ public class Pipeline {
      * Called for each primitive transform after all of its topological predecessors
      * and inputs have been visited.
      */
-    public void visitTransform(TransformTreeNode node);
+    public void visitPrimitiveTransform(TransformTreeNode node);
 
     /**
      * Called for each value after the transform that produced the value has been
      * visited.
      */
     public void visitValue(PValue value, TransformTreeNode producer);
+
+    /**
+     * Control enum for indicating whether or not a traversal should process the contents
of
+     * a composite transform or not.
+     */
+    public enum CompositeBehavior {
+      ENTER_TRANSFORM,
+      DO_NOT_ENTER_TRANSFORM;
+    }
+
+    /**
+     * Default no-op {@link PipelineVisitor} that enters all composite transforms.
+     * User implementations can override just those methods they are interested in.
+     */
+    public class Defaults implements PipelineVisitor {
+      @Override
+      public CompositeBehavior enterCompositeTransform(TransformTreeNode node) {
+        return CompositeBehavior.ENTER_TRANSFORM;
+      }
+
+      @Override
+      public void leaveCompositeTransform(TransformTreeNode node) { }
+
+      @Override
+      public void visitPrimitiveTransform(TransformTreeNode node) { }
+
+      @Override
+      public void visitValue(PValue value, TransformTreeNode producer) { }
+    }
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/dbf7a06a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/AggregatorPipelineExtractor.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/AggregatorPipelineExtractor.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/AggregatorPipelineExtractor.java
index 86a851f..146ddfa 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/AggregatorPipelineExtractor.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/AggregatorPipelineExtractor.java
@@ -56,7 +56,7 @@ public class AggregatorPipelineExtractor {
     return aggregatorSteps.asMap();
   }
 
-  private static class AggregatorVisitor implements PipelineVisitor {
+  private static class AggregatorVisitor extends PipelineVisitor.Defaults {
     private final SetMultimap<Aggregator<?, ?>, PTransform<?, ?>> aggregatorSteps;
 
     public AggregatorVisitor(SetMultimap<Aggregator<?, ?>, PTransform<?, ?>>
aggregatorSteps) {
@@ -64,13 +64,7 @@ public class AggregatorPipelineExtractor {
     }
 
     @Override
-    public void enterCompositeTransform(TransformTreeNode node) {}
-
-    @Override
-    public void leaveCompositeTransform(TransformTreeNode node) {}
-
-    @Override
-    public void visitTransform(TransformTreeNode node) {
+    public void visitPrimitiveTransform(TransformTreeNode node) {
       PTransform<?, ?> transform = node.getTransform();
       addStepToAggregators(transform, getAggregators(transform));
     }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/dbf7a06a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/DirectPipelineRunner.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/DirectPipelineRunner.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/DirectPipelineRunner.java
index 3cb9703..590ce6f 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/DirectPipelineRunner.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/DirectPipelineRunner.java
@@ -828,7 +828,7 @@ public class DirectPipelineRunner
 
   /////////////////////////////////////////////////////////////////////////////
 
-  class Evaluator implements PipelineVisitor, EvaluationContext {
+  class Evaluator extends PipelineVisitor.Defaults implements EvaluationContext {
     /**
      * A map from PTransform to the step name of that transform. This is the internal name
for the
      * transform (e.g. "s2").
@@ -881,15 +881,7 @@ public class DirectPipelineRunner
     }
 
     @Override
-    public void enterCompositeTransform(TransformTreeNode node) {
-    }
-
-    @Override
-    public void leaveCompositeTransform(TransformTreeNode node) {
-    }
-
-    @Override
-    public void visitTransform(TransformTreeNode node) {
+    public void visitPrimitiveTransform(TransformTreeNode node) {
       PTransform<?, ?> transform = node.getTransform();
       fullNames.put(transform, node.getFullName());
       TransformEvaluator evaluator =

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/dbf7a06a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/RecordingPipelineVisitor.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/RecordingPipelineVisitor.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/RecordingPipelineVisitor.java
index 84df5fd..d64738f 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/RecordingPipelineVisitor.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/RecordingPipelineVisitor.java
@@ -30,21 +30,13 @@ import java.util.List;
  *
  * <p>Provided for internal unit tests.
  */
-public class RecordingPipelineVisitor implements Pipeline.PipelineVisitor {
+public class RecordingPipelineVisitor extends Pipeline.PipelineVisitor.Defaults {
 
   public final List<PTransform<?, ?>> transforms = new ArrayList<>();
   public final List<PValue> values = new ArrayList<>();
 
   @Override
-  public void enterCompositeTransform(TransformTreeNode node) {
-  }
-
-  @Override
-  public void leaveCompositeTransform(TransformTreeNode node) {
-  }
-
-  @Override
-  public void visitTransform(TransformTreeNode node) {
+  public void visitPrimitiveTransform(TransformTreeNode node) {
     transforms.add(node.getTransform());
   }
 

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/dbf7a06a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformTreeNode.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformTreeNode.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformTreeNode.java
index a6efc51..59edd52 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformTreeNode.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformTreeNode.java
@@ -17,7 +17,8 @@
  */
 package org.apache.beam.sdk.runners;
 
-import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.Pipeline.PipelineVisitor;
+import org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.values.PInput;
 import org.apache.beam.sdk.values.POutput;
@@ -198,7 +199,7 @@ public class TransformTreeNode {
    * transform (or child nodes for composite transforms), then the
    * output values.
    */
-  public void visit(Pipeline.PipelineVisitor visitor,
+  public void visit(PipelineVisitor visitor,
                     Set<PValue> visitedValues) {
     if (!finishedSpecifying) {
       finishSpecifying();
@@ -212,13 +213,16 @@ public class TransformTreeNode {
     }
 
     if (isCompositeNode()) {
-      visitor.enterCompositeTransform(this);
-      for (TransformTreeNode child : parts) {
-        child.visit(visitor, visitedValues);
+      PipelineVisitor.CompositeBehavior recurse = visitor.enterCompositeTransform(this);
+
+      if (recurse.equals(CompositeBehavior.ENTER_TRANSFORM)) {
+        for (TransformTreeNode child : parts) {
+          child.visit(visitor, visitedValues);
+        }
       }
       visitor.leaveCompositeTransform(this);
     } else {
-      visitor.visitTransform(this);
+      visitor.visitPrimitiveTransform(this);
     }
 
     // Visit outputs.

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/dbf7a06a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/AggregatorPipelineExtractorTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/AggregatorPipelineExtractorTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/AggregatorPipelineExtractorTest.java
index 7950a9e..74cc5e0 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/AggregatorPipelineExtractorTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/AggregatorPipelineExtractorTest.java
@@ -205,7 +205,7 @@ public class AggregatorPipelineExtractorTest {
     public Object answer(InvocationOnMock invocation) throws Throwable {
       PipelineVisitor visitor = (PipelineVisitor) invocation.getArguments()[0];
       for (TransformTreeNode node : nodes) {
-        visitor.visitTransform(node);
+        visitor.visitPrimitiveTransform(node);
       }
       return null;
     }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/dbf7a06a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/TransformTreeTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/TransformTreeTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/TransformTreeTest.java
index e4eb204..aecebd7 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/TransformTreeTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/TransformTreeTest.java
@@ -40,7 +40,6 @@ import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionList;
 import org.apache.beam.sdk.values.PDone;
-import org.apache.beam.sdk.values.PValue;
 
 import org.junit.Rule;
 import org.junit.Test;
@@ -128,9 +127,9 @@ public class TransformTreeTest {
     final EnumSet<TransformsSeen> left =
         EnumSet.noneOf(TransformsSeen.class);
 
-    p.traverseTopologically(new Pipeline.PipelineVisitor() {
+    p.traverseTopologically(new Pipeline.PipelineVisitor.Defaults() {
       @Override
-      public void enterCompositeTransform(TransformTreeNode node) {
+      public CompositeBehavior enterCompositeTransform(TransformTreeNode node) {
         PTransform<?, ?> transform = node.getTransform();
         if (transform instanceof Sample.SampleAny) {
           assertTrue(visited.add(TransformsSeen.SAMPLE_ANY));
@@ -142,6 +141,7 @@ public class TransformTreeTest {
           assertTrue(node.isCompositeNode());
         }
         assertThat(transform, not(instanceOf(Read.Bounded.class)));
+        return CompositeBehavior.ENTER_TRANSFORM;
       }
 
       @Override
@@ -153,7 +153,7 @@ public class TransformTreeTest {
       }
 
       @Override
-      public void visitTransform(TransformTreeNode node) {
+      public void visitPrimitiveTransform(TransformTreeNode node) {
         PTransform<?, ?> transform = node.getTransform();
         // Pick is a composite, should not be visited here.
         assertThat(transform, not(instanceOf(Sample.SampleAny.class)));
@@ -163,10 +163,6 @@ public class TransformTreeTest {
           assertTrue(visited.add(TransformsSeen.READ));
         }
       }
-
-      @Override
-      public void visitValue(PValue value, TransformTreeNode producer) {
-      }
     });
 
     assertTrue(visited.equals(EnumSet.allOf(TransformsSeen.class)));



Mime
View raw message