beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tg...@apache.org
Subject [2/2] beam git commit: Remove PipelineRunner#apply
Date Wed, 01 Mar 2017 21:06:02 GMT
Remove PipelineRunner#apply

All existing Pipeline Runners that use the Java SDK modify Pipeline
graphs with the Pipeline Surgery APIs. Apply is now superfluous.

Add an AssertionCountingVisitor to enable TestRunners to track the
number of assertions in the Pipeline.


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/3408f604
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/3408f604
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/3408f604

Branch: refs/heads/master
Commit: 3408f6049ba3692f9edbbeead75626125954d4b6
Parents: a81c457
Author: Thomas Groh <tgroh@google.com>
Authored: Thu Feb 23 17:32:01 2017 -0800
Committer: Thomas Groh <tgroh@google.com>
Committed: Wed Mar 1 13:05:49 2017 -0800

----------------------------------------------------------------------
 .../beam/runners/apex/TestApexRunner.java       | 10 ---
 .../beam/runners/flink/TestFlinkRunner.java     |  9 ---
 .../dataflow/testing/TestDataflowRunner.java    | 17 ++---
 .../testing/TestDataflowRunnerTest.java         |  3 +-
 runners/spark/pom.xml                           |  4 ++
 .../beam/runners/spark/TestSparkRunner.java     | 76 ++++++++++++--------
 .../beam/runners/spark/ForceStreamingTest.java  |  2 +
 .../main/java/org/apache/beam/sdk/Pipeline.java |  4 +-
 .../apache/beam/sdk/runners/PipelineRunner.java | 14 ----
 .../org/apache/beam/sdk/testing/PAssert.java    | 64 +++++++++++++++++
 .../apache/beam/sdk/testing/PAssertTest.java    | 30 ++++++++
 11 files changed, 154 insertions(+), 79 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/3408f604/runners/apex/src/main/java/org/apache/beam/runners/apex/TestApexRunner.java
----------------------------------------------------------------------
diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/TestApexRunner.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/TestApexRunner.java
index e447e37..a64ac54 100644
--- a/runners/apex/src/main/java/org/apache/beam/runners/apex/TestApexRunner.java
+++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/TestApexRunner.java
@@ -18,14 +18,10 @@
 package org.apache.beam.runners.apex;
 
 import java.io.IOException;
-
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.PipelineOptionsValidator;
 import org.apache.beam.sdk.runners.PipelineRunner;
-import org.apache.beam.sdk.transforms.PTransform;
-import org.apache.beam.sdk.values.PInput;
-import org.apache.beam.sdk.values.POutput;
 import org.joda.time.Duration;
 
 /**
@@ -49,12 +45,6 @@ public class TestApexRunner extends PipelineRunner<ApexRunnerResult>
{
   }
 
   @Override
-  public <OutputT extends POutput, InputT extends PInput>
-      OutputT apply(PTransform<InputT, OutputT> transform, InputT input) {
-    return delegate.apply(transform, input);
-  }
-
-  @Override
   public ApexRunnerResult run(Pipeline pipeline) {
     ApexRunnerResult result = delegate.run(pipeline);
     try {

http://git-wip-us.apache.org/repos/asf/beam/blob/3408f604/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/TestFlinkRunner.java
----------------------------------------------------------------------
diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/TestFlinkRunner.java
b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/TestFlinkRunner.java
index 30a94a1..ef56b55 100644
--- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/TestFlinkRunner.java
+++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/TestFlinkRunner.java
@@ -24,10 +24,7 @@ import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.options.PipelineOptionsValidator;
 import org.apache.beam.sdk.runners.PipelineRunner;
-import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.util.UserCodeException;
-import org.apache.beam.sdk.values.PInput;
-import org.apache.beam.sdk.values.POutput;
 
 /**
  * Test Flink runner.
@@ -56,12 +53,6 @@ public class TestFlinkRunner extends PipelineRunner<PipelineResult>
{
   }
 
   @Override
-  public <OutputT extends POutput, InputT extends PInput>
-      OutputT apply(PTransform<InputT, OutputT> transform, InputT input) {
-    return delegate.apply(transform, input);
-  }
-
-  @Override
   public PipelineResult run(Pipeline pipeline) {
     try {
       return delegate.run(pipeline);

http://git-wip-us.apache.org/repos/asf/beam/blob/3408f604/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/testing/TestDataflowRunner.java
----------------------------------------------------------------------
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/testing/TestDataflowRunner.java
b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/testing/TestDataflowRunner.java
index 0564448..5315671 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/testing/TestDataflowRunner.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/testing/TestDataflowRunner.java
@@ -46,9 +46,6 @@ import org.apache.beam.sdk.runners.PipelineRunner;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.testing.TestPipelineOptions;
-import org.apache.beam.sdk.transforms.PTransform;
-import org.apache.beam.sdk.values.PInput;
-import org.apache.beam.sdk.values.POutput;
 import org.joda.time.Duration;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -102,6 +99,7 @@ public class TestDataflowRunner extends PipelineRunner<DataflowPipelineJob>
{
   }
 
   DataflowPipelineJob run(Pipeline pipeline, DataflowRunner runner) {
+    updatePAssertCount(pipeline);
 
     TestPipelineOptions testPipelineOptions = pipeline.getOptions().as(TestPipelineOptions.class);
     final DataflowPipelineJob job;
@@ -183,16 +181,9 @@ public class TestDataflowRunner extends PipelineRunner<DataflowPipelineJob>
{
     return job;
   }
 
-  @Override
-  public <OutputT extends POutput, InputT extends PInput> OutputT apply(
-      PTransform<InputT, OutputT> transform, InputT input) {
-    if (transform instanceof PAssert.OneSideInputAssert
-        || transform instanceof PAssert.GroupThenAssert
-        || transform instanceof PAssert.GroupThenAssertForSingleton) {
-      expectedNumberOfAssertions += 1;
-    }
-
-    return runner.apply(transform, input);
+  @VisibleForTesting
+  void updatePAssertCount(Pipeline pipeline) {
+    expectedNumberOfAssertions = PAssert.countAsserts(pipeline);
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/beam/blob/3408f604/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/testing/TestDataflowRunnerTest.java
----------------------------------------------------------------------
diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/testing/TestDataflowRunnerTest.java
b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/testing/TestDataflowRunnerTest.java
index da5630b..1e906d2 100644
--- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/testing/TestDataflowRunnerTest.java
+++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/testing/TestDataflowRunnerTest.java
@@ -378,7 +378,8 @@ public class TestDataflowRunnerTest {
     PCollection<Integer> pc = p.apply(Create.of(1, 2, 3));
     PAssert.that(pc).containsInAnyOrder(1, 2, 3);
 
-    TestDataflowRunner runner = (TestDataflowRunner) p.getRunner();
+    TestDataflowRunner runner = TestDataflowRunner.fromOptions(options);
+    runner.updatePAssertCount(p);
     doReturn(State.RUNNING).when(job).getState();
     JobMetrics metrics = buildJobMetrics(
         generateMockMetrics(true /* success */, false /* tentative */));

http://git-wip-us.apache.org/repos/asf/beam/blob/3408f604/runners/spark/pom.xml
----------------------------------------------------------------------
diff --git a/runners/spark/pom.xml b/runners/spark/pom.xml
index 8c35178..44f20cc 100644
--- a/runners/spark/pom.xml
+++ b/runners/spark/pom.xml
@@ -226,6 +226,10 @@
     </dependency>
     <dependency>
       <groupId>org.apache.beam</groupId>
+      <artifactId>beam-runners-core-construction-java</artifactId>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.beam</groupId>
       <artifactId>beam-runners-core-java</artifactId>
       <exclusions>
         <!-- Use Hadoop/Spark's backend logger instead of jdk14 for tests -->

http://git-wip-us.apache.org/repos/asf/beam/blob/3408f604/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
index 035da00..16ddc9e 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
@@ -22,9 +22,14 @@ import static com.google.common.base.Preconditions.checkNotNull;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.is;
 
+import com.google.common.annotations.VisibleForTesting;
 import java.io.File;
 import java.io.IOException;
+import java.util.List;
+import java.util.Map;
 import org.apache.beam.runners.core.UnboundedReadFromBoundedSource;
+import org.apache.beam.runners.core.construction.PTransformMatchers;
+import org.apache.beam.runners.core.construction.ReplacementOutputs;
 import org.apache.beam.runners.spark.aggregators.AggregatorsAccumulator;
 import org.apache.beam.runners.spark.metrics.SparkMetricsContainer;
 import org.apache.beam.runners.spark.util.GlobalWatermarkHolder;
@@ -32,6 +37,7 @@ import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.io.BoundedReadFromUnboundedSource;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.PipelineOptionsValidator;
+import org.apache.beam.sdk.runners.PTransformOverrideFactory;
 import org.apache.beam.sdk.runners.PipelineRunner;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.transforms.PTransform;
@@ -39,14 +45,13 @@ import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.util.ValueWithRecordId;
 import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
-import org.apache.beam.sdk.values.PInput;
-import org.apache.beam.sdk.values.POutput;
+import org.apache.beam.sdk.values.PValue;
+import org.apache.beam.sdk.values.TaggedPValue;
 import org.apache.commons.io.FileUtils;
 import org.joda.time.Duration;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-
 /**
  * The SparkRunner translate operations defined on a pipeline to a representation executable
  * by Spark, and then submitting the job to Spark to be executed. If we wanted to run a Beam
@@ -74,7 +79,6 @@ public final class TestSparkRunner extends PipelineRunner<SparkPipelineResult>
{
 
   private SparkRunner delegate;
   private boolean isForceStreaming;
-  private int expectedNumberOfAssertions = 0;
 
   private TestSparkRunner(TestSparkPipelineOptions options) {
     this.delegate = SparkRunner.fromOptions(options);
@@ -88,37 +92,22 @@ public final class TestSparkRunner extends PipelineRunner<SparkPipelineResult>
{
     return new TestSparkRunner(sparkOptions);
   }
 
-  /**
-   * Overrides for the test runner.
-   */
-  @SuppressWarnings("unchecked")
   @Override
-  public <OutputT extends POutput, InputT extends PInput> OutputT apply(
-      PTransform<InputT, OutputT> transform, InputT input) {
+  public SparkPipelineResult run(Pipeline pipeline) {
+    TestSparkPipelineOptions testSparkPipelineOptions =
+        pipeline.getOptions().as(TestSparkPipelineOptions.class);
+    //
     // if the pipeline forces execution as a streaming pipeline,
     // and the source is an adapted unbounded source (as bounded),
     // read it as unbounded source via UnboundedReadFromBoundedSource.
-    if (isForceStreaming && transform instanceof BoundedReadFromUnboundedSource) {
-      return (OutputT) delegate.apply(new AdaptedBoundedAsUnbounded(
-          (BoundedReadFromUnboundedSource) transform), input);
-    } else {
-      // no actual override, simply counts asserting transforms in the pipeline.
-      if (transform instanceof PAssert.OneSideInputAssert
-          || transform instanceof PAssert.GroupThenAssert
-          || transform instanceof PAssert.GroupThenAssertForSingleton) {
-        expectedNumberOfAssertions += 1;
-      }
-
-      return delegate.apply(transform, input);
+    if (isForceStreaming) {
+      adaptBoundedReads(pipeline);
     }
-  }
-
-  @Override
-  public SparkPipelineResult run(Pipeline pipeline) {
-    TestSparkPipelineOptions testSparkPipelineOptions =
-        pipeline.getOptions().as(TestSparkPipelineOptions.class);
     SparkPipelineResult result = null;
-    // clear state of Aggregators, Metrics and Watermarks.
+
+    int expectedNumberOfAssertions = PAssert.countAsserts(pipeline);
+
+    // clear state of Aggregators, Metrics and Watermarks if exists.
     AggregatorsAccumulator.clear();
     SparkMetricsContainer.clear();
     GlobalWatermarkHolder.clear();
@@ -170,6 +159,13 @@ public final class TestSparkRunner extends PipelineRunner<SparkPipelineResult>
{
     return result;
   }
 
+  @VisibleForTesting
+  void adaptBoundedReads(Pipeline pipeline) {
+    pipeline.replace(
+        PTransformMatchers.classEqualTo(BoundedReadFromUnboundedSource.class),
+        new AdaptedBoundedAsUnbounded.Factory());
+  }
+
   private static class AdaptedBoundedAsUnbounded<T> extends PTransform<PBegin, PCollection<T>> {
     private final BoundedReadFromUnboundedSource<T> source;
 
@@ -185,6 +181,26 @@ public final class TestSparkRunner extends PipelineRunner<SparkPipelineResult>
{
       return (PCollection<T>) input.apply(replacingTransform)
           .apply("StripIds", ParDo.of(new ValueWithRecordId.StripIdsDoFn()));
     }
-  }
 
+    static class Factory<T>
+        implements PTransformOverrideFactory<
+            PBegin, PCollection<T>, BoundedReadFromUnboundedSource<T>> {
+      @Override
+      public PTransform<PBegin, PCollection<T>> getReplacementTransform(
+          BoundedReadFromUnboundedSource<T> transform) {
+        return new AdaptedBoundedAsUnbounded<>(transform);
+      }
+
+      @Override
+      public PBegin getInput(List<TaggedPValue> inputs, Pipeline p) {
+        return p.begin();
+      }
+
+      @Override
+      public Map<PValue, ReplacementOutput> mapOutputs(
+          List<TaggedPValue> outputs, PCollection<T> newOutput) {
+        return ReplacementOutputs.singleton(outputs, newOutput);
+      }
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/3408f604/runners/spark/src/test/java/org/apache/beam/runners/spark/ForceStreamingTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/ForceStreamingTest.java
b/runners/spark/src/test/java/org/apache/beam/runners/spark/ForceStreamingTest.java
index 9b39558..b60faf2 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/ForceStreamingTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/ForceStreamingTest.java
@@ -56,6 +56,8 @@ public class ForceStreamingTest {
         Read.from(CountingSource.unbounded()).withMaxNumRecords(-1);
     //noinspection unchecked
     pipeline.apply(boundedRead);
+    TestSparkRunner runner = TestSparkRunner.fromOptions(pipelineRule.getOptions());
+    runner.adaptBoundedReads(pipeline);
 
     UnboundedReadDetector unboundedReadDetector = new UnboundedReadDetector();
     pipeline.traverseTopologically(unboundedReadDetector);

http://git-wip-us.apache.org/repos/asf/beam/blob/3408f604/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java
index 109424d..fe1d526 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java
@@ -417,7 +417,7 @@ public class Pipeline {
     try {
       transforms.finishSpecifyingInput();
       transform.validate(input);
-      OutputT output = runner.apply(transform, input);
+      OutputT output = transform.expand(input);
       transforms.setOutput(output);
 
       return output;
@@ -444,7 +444,7 @@ public class Pipeline {
     LOG.debug("Replacing {} with {}", original, replacement);
     transforms.replaceNode(original, originalInput, replacement);
     try {
-      OutputT newOutput = runner.apply(replacement, originalInput);
+      OutputT newOutput = replacement.expand(originalInput);
       Map<PValue, ReplacementOutput> originalToReplacement =
           replacementFactory.mapOutputs(original.getOutputs(), newOutput);
       // Ensure the internal TransformHierarchy data structures are consistent.

http://git-wip-us.apache.org/repos/asf/beam/blob/3408f604/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/PipelineRunner.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/PipelineRunner.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/PipelineRunner.java
index 8604dbc..80bb90f 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/PipelineRunner.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/PipelineRunner.java
@@ -24,11 +24,8 @@ import org.apache.beam.sdk.PipelineResult;
 import org.apache.beam.sdk.options.GcsOptions;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.PipelineOptionsValidator;
-import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.util.IOChannelUtils;
 import org.apache.beam.sdk.util.InstanceBuilder;
-import org.apache.beam.sdk.values.PInput;
-import org.apache.beam.sdk.values.POutput;
 
 /**
  * A {@link PipelineRunner} can execute, translate, or otherwise process a
@@ -64,15 +61,4 @@ public abstract class PipelineRunner<ResultT extends PipelineResult>
{
    * Processes the given Pipeline, returning the results.
    */
   public abstract ResultT run(Pipeline pipeline);
-
-  /**
-   * Applies a transform to the given input, returning the output.
-   *
-   * <p>The default implementation calls PTransform.apply(input), but can be overridden
-   * to customize behavior for a particular runner.
-   */
-  public <OutputT extends POutput, InputT extends PInput> OutputT apply(
-      PTransform<InputT, OutputT> transform, InputT input) {
-    return transform.expand(input);
-  }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/3408f604/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/PAssert.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/PAssert.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/PAssert.java
index b57f4a9..a6fb232e 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/PAssert.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/PAssert.java
@@ -17,6 +17,7 @@
  */
 package org.apache.beam.sdk.testing;
 
+import static com.google.common.base.Preconditions.checkState;
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.not;
@@ -32,6 +33,7 @@ import java.util.Collections;
 import java.util.Map;
 import java.util.NoSuchElementException;
 import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.Pipeline.PipelineVisitor;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.CoderException;
 import org.apache.beam.sdk.coders.IterableCoder;
@@ -40,6 +42,7 @@ import org.apache.beam.sdk.coders.MapCoder;
 import org.apache.beam.sdk.coders.VarIntCoder;
 import org.apache.beam.sdk.options.StreamingOptions;
 import org.apache.beam.sdk.runners.PipelineRunner;
+import org.apache.beam.sdk.runners.TransformHierarchy.Node;
 import org.apache.beam.sdk.transforms.Aggregator;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
@@ -1334,4 +1337,65 @@ public class PAssert {
       }
     }
   }
+
+  public static int countAsserts(Pipeline pipeline) {
+    AssertionCountingVisitor visitor = new AssertionCountingVisitor();
+    pipeline.traverseTopologically(visitor);
+    return visitor.getPAssertCount();
+  }
+
+  /**
+   * A {@link PipelineVisitor} that counts the number of total {@link PAssert PAsserts} in a
+   * {@link Pipeline}.
+   */
+  private static class AssertionCountingVisitor extends PipelineVisitor.Defaults {
+    private int assertCount;
+    private boolean pipelineVisited;
+
+    private AssertionCountingVisitor() {
+      assertCount = 0;
+      pipelineVisited = false;
+    }
+
+    @Override
+    public CompositeBehavior enterCompositeTransform(Node node) {
+      if (node.isRootNode()) {
+        checkState(
+            !pipelineVisited,
+            "Tried to visit a pipeline with an already used %s",
+            AssertionCountingVisitor.class.getSimpleName());
+      }
+      if (!node.isRootNode()
+          && (node.getTransform() instanceof PAssert.OneSideInputAssert
+          || node.getTransform() instanceof PAssert.GroupThenAssert
+          || node.getTransform() instanceof PAssert.GroupThenAssertForSingleton)) {
+        assertCount++;
+      }
+      return CompositeBehavior.ENTER_TRANSFORM;
+    }
+
+    public void leaveCompositeTransform(Node node) {
+      if (node.isRootNode()) {
+        pipelineVisited = true;
+      }
+    }
+
+    @Override
+    public void visitPrimitiveTransform(Node node) {
+      if
+          (node.getTransform() instanceof PAssert.OneSideInputAssert
+          || node.getTransform() instanceof PAssert.GroupThenAssert
+          || node.getTransform() instanceof PAssert.GroupThenAssertForSingleton) {
+        assertCount++;
+      }
+    }
+
+    /**
+     * Gets the number of {@link PAssert PAsserts} in the pipeline.
+     */
+    int getPAssertCount() {
+      checkState(pipelineVisited);
+      return assertCount;
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/3408f604/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PAssertTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PAssertTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PAssertTest.java
index e57a254..777e1af 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PAssertTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PAssertTest.java
@@ -18,6 +18,7 @@
 package org.apache.beam.sdk.testing;
 
 import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.is;
 import static org.junit.Assert.assertThat;
 import static org.junit.Assert.assertTrue;
@@ -29,6 +30,7 @@ import java.io.IOException;
 import java.io.InputStream;
 import java.io.OutputStream;
 import java.io.Serializable;
+import java.util.Collections;
 import java.util.regex.Pattern;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.AtomicCoder;
@@ -37,12 +39,14 @@ import org.apache.beam.sdk.coders.VarLongCoder;
 import org.apache.beam.sdk.io.CountingInput;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.transforms.Sum;
 import org.apache.beam.sdk.transforms.windowing.FixedWindows;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
 import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
 import org.apache.beam.sdk.transforms.windowing.SlidingWindows;
 import org.apache.beam.sdk.transforms.windowing.Window;
 import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
+import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.TimestampedValue;
 import org.joda.time.Duration;
@@ -403,4 +407,30 @@ public class PAssertTest implements Serializable {
     fail("assertion should have failed");
     throw new RuntimeException("unreachable");
   }
+
+  @Test
+  public void countAssertsSucceeds() {
+    PCollection<Integer> create = pipeline.apply("FirstCreate", Create.of(1, 2, 3));
+
+    PAssert.that(create).containsInAnyOrder(1, 2, 3);
+    PAssert.thatSingleton(create.apply(Sum.integersGlobally())).isEqualTo(6);
+    PAssert.thatMap(pipeline.apply("CreateMap", Create.of(KV.of(1, 2))))
+        .isEqualTo(Collections.singletonMap(1, 2));
+
+    assertThat(PAssert.countAsserts(pipeline), equalTo(3));
+  }
+
+  @Test
+  public void countAssertsMultipleCallsIndependent() {
+    PCollection<Integer> create = pipeline.apply("FirstCreate", Create.of(1, 2, 3));
+
+    PAssert.that(create).containsInAnyOrder(1, 2, 3);
+    PAssert.thatSingleton(create.apply(Sum.integersGlobally())).isEqualTo(6);
+    assertThat(PAssert.countAsserts(pipeline), equalTo(2));
+
+    PAssert.thatMap(pipeline.apply("CreateMap", Create.of(KV.of(1, 2))))
+        .isEqualTo(Collections.singletonMap(1, 2));
+
+    assertThat(PAssert.countAsserts(pipeline), equalTo(3));
+  }
 }


Mime
View raw message