beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From dhalp...@apache.org
Subject [2/2] incubator-beam git commit: Support for @Setup and @Teardown in DoFnTester
Date Tue, 27 Sep 2016 21:57:59 GMT
Support for @Setup and @Teardown in DoFnTester


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/bef0e9de
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/bef0e9de
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/bef0e9de

Branch: refs/heads/master
Commit: bef0e9de02be051411f20b298168e8477ed1a0da
Parents: 9009802
Author: Eugene Kirpichov <kirpichov@google.com>
Authored: Mon Sep 26 16:58:20 2016 -0700
Committer: Dan Halperin <dhalperi@google.com>
Committed: Tue Sep 27 14:57:49 2016 -0700

----------------------------------------------------------------------
 .../apache/beam/sdk/transforms/DoFnTester.java  | 120 +++--
 .../beam/sdk/transforms/DoFnTesterTest.java     | 456 +++++++++++--------
 2 files changed, 338 insertions(+), 238 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/bef0e9de/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java
index 0e018ba..9adb806 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java
@@ -78,10 +78,10 @@ import org.joda.time.Instant;
  * @param <InputT> the type of the {@link DoFn}'s (main) input elements
  * @param <OutputT> the type of the {@link DoFn}'s (main) output elements
  */
-public class DoFnTester<InputT, OutputT> {
+public class DoFnTester<InputT, OutputT> implements AutoCloseable {
   /**
    * Returns a {@code DoFnTester} supporting unit-testing of the given
-   * {@link DoFn}.
+   * {@link DoFn}. By default, uses {@link CloningBehavior#CLONE_ONCE}.
    */
   @SuppressWarnings("unchecked")
   public static <InputT, OutputT> DoFnTester<InputT, OutputT> of(DoFn<InputT, OutputT> fn) {
@@ -91,6 +91,8 @@ public class DoFnTester<InputT, OutputT> {
   /**
    * Returns a {@code DoFnTester} supporting unit-testing of the given
    * {@link OldDoFn}.
+   *
+   * @see #of(DoFn)
    */
   @SuppressWarnings("unchecked")
    public static <InputT, OutputT> DoFnTester<InputT, OutputT>
@@ -108,8 +110,11 @@ public class DoFnTester<InputT, OutputT> {
    * {@link DoFn} takes no side inputs.
    */
   public void setSideInputs(Map<PCollectionView<?>, Map<BoundedWindow, ?>> sideInputs) {
+    checkState(
+        state == State.UNINITIALIZED,
+        "Can't add side inputs: DoFnTester is already initialized, in state %s",
+        state);
     this.sideInputs = sideInputs;
-    resetState();
   }
 
   /**
@@ -123,6 +128,10 @@ public class DoFnTester<InputT, OutputT> {
    * that is used.
    */
   public <T> void setSideInput(PCollectionView<T> sideInput, BoundedWindow window, T value) {
+    checkState(
+        state == State.UNINITIALIZED,
+        "Can't add side inputs: DoFnTester is already initialized, in state %s",
+        state);
     Map<BoundedWindow, T> windowValues = (Map<BoundedWindow, T>) sideInputs.get(sideInput);
     if (windowValues == null) {
       windowValues = new HashMap<>();
@@ -132,10 +141,24 @@ public class DoFnTester<InputT, OutputT> {
   }
 
   /**
-   * Whether or not a {@link DoFnTester} should clone the {@link DoFn} under test.
+   * When a {@link DoFnTester} should clone the {@link DoFn} under test and how it should manage
+   * the lifecycle of the {@link DoFn}.
    */
   public enum CloningBehavior {
-    CLONE,
+    /**
+     * Clone the {@link DoFn} and call {@link DoFn.Setup} every time a bundle starts; call {@link
+     * DoFn.Teardown} every time a bundle finishes.
+     */
+    CLONE_PER_BUNDLE,
+    /**
+     * Clone the {@link DoFn} and call {@link DoFn.Setup} on the first access; call {@link
+     * DoFn.Teardown} only when {@link DoFnTester#close} is called.
+     */
+    CLONE_ONCE,
+    /**
+     * Do not clone the {@link DoFn}; call {@link DoFn.Setup} on the first access; call {@link
+     * DoFn.Teardown} only when {@link DoFnTester#close} is called.
+     */
     DO_NOT_CLONE
   }
 
@@ -143,6 +166,7 @@ public class DoFnTester<InputT, OutputT> {
    * Instruct this {@link DoFnTester} whether or not to clone the {@link DoFn} under test.
    */
   public void setCloningBehavior(CloningBehavior newValue) {
+    checkState(state == State.UNINITIALIZED, "Wrong state: %s", state);
     this.cloningBehavior = newValue;
   }
 
@@ -187,11 +211,17 @@ public class DoFnTester<InputT, OutputT> {
   /**
    * Calls the {@link DoFn.StartBundle} method on the {@link DoFn} under test.
    *
-   * <p>If needed, first creates a fresh instance of the {@link DoFn} under test.
+   * <p>If needed, first creates a fresh instance of the {@link DoFn} under test and calls
+   * {@link DoFn.Setup}.
    */
   public void startBundle() throws Exception {
-    resetState();
-    initializeState();
+    checkState(
+        state == State.UNINITIALIZED || state == State.BUNDLE_FINISHED,
+        "Wrong state during startBundle: %s",
+        state);
+    if (state == State.UNINITIALIZED) {
+      initializeState();
+    }
     TestContext<InputT, OutputT> context = createContext(fn);
     context.setupDelegateAggregators();
     try {
@@ -199,7 +229,7 @@ public class DoFnTester<InputT, OutputT> {
     } catch (UserCodeException e) {
       unwrapUserCodeException(e);
     }
-    state = State.STARTED;
+    state = State.BUNDLE_STARTED;
   }
 
   private static void unwrapUserCodeException(UserCodeException e) throws Exception {
@@ -236,15 +266,10 @@ public class DoFnTester<InputT, OutputT> {
    * already been called.
    *
    * <p>If the input timestamp is {@literal null}, the minimum timestamp will be used.
-   *
-   * @throws IllegalStateException if the {@code OldDoFn} under test has already
-   * been finished
    */
   public void processTimestampedElement(TimestampedValue<InputT> element) throws Exception {
     checkNotNull(element, "Timestamped element cannot be null");
-    checkState(state != State.FINISHED, "finishBundle() has already been called");
-
-    if (state == State.UNSTARTED) {
+    if (state != State.BUNDLE_STARTED) {
       startBundle();
     }
     try {
@@ -257,25 +282,30 @@ public class DoFnTester<InputT, OutputT> {
   /**
    * Calls the {@link DoFn.FinishBundle} method of the {@link DoFn} under test.
    *
-   * <p>Will call {@link #startBundle} automatically, if it hasn't
-   * already been called.
+   * <p>If {@link #setCloningBehavior} was called with {@link CloningBehavior#CLONE_PER_BUNDLE},
+   * then also calls {@link DoFn.Teardown} on the {@link DoFn}, and it will be cloned and
+   * {@link DoFn.Setup} will be called again when processing the next bundle.
    *
-   * @throws IllegalStateException if the {@link DoFn} under test has already
-   * been finished
+   * @throws IllegalStateException if no bundle is currently in progress, e.g. because
+   * {@link DoFn.FinishBundle} has already been called for this bundle.
    */
   public void finishBundle() throws Exception {
-    if (state == State.FINISHED) {
-      throw new IllegalStateException("finishBundle() has already been called");
-    }
-    if (state == State.UNSTARTED) {
-      startBundle();
-    }
+    checkState(
+        state == State.BUNDLE_STARTED,
+        "Must be inside bundle to call finishBundle, but was: %s",
+        state);
     try {
       fn.finishBundle(createContext(fn));
     } catch (UserCodeException e) {
       unwrapUserCodeException(e);
     }
-    state = State.FINISHED;
+    if (cloningBehavior == CloningBehavior.CLONE_PER_BUNDLE) {
+      fn.teardown();
+      fn = null;
+      state = State.UNINITIALIZED;
+    } else {
+      state = State.BUNDLE_FINISHED;
+    }
   }
 
   /**
@@ -695,13 +725,26 @@ public class DoFnTester<InputT, OutputT> {
     }
   }
 
+  @Override
+  public void close() throws Exception {
+    if (state == State.BUNDLE_STARTED) {
+      finishBundle();
+    }
+    if (state == State.BUNDLE_FINISHED) {
+      fn.teardown();
+      fn = null;
+    }
+    state = State.TORN_DOWN;
+  }
+
   /////////////////////////////////////////////////////////////////////////////
 
   /** The possible states of processing a {@link DoFn}. */
-  enum State {
-    UNSTARTED,
-    STARTED,
-    FINISHED
+  private enum State {
+    UNINITIALIZED,
+    BUNDLE_STARTED,
+    BUNDLE_FINISHED,
+    TORN_DOWN
   }
 
   private final PipelineOptions options = PipelineOptionsFactory.create();
@@ -714,7 +757,7 @@ public class DoFnTester<InputT, OutputT> {
    *
   * <p>Worker-side {@link DoFn DoFns} may not be serializable, and are not required to be.
    */
-  private CloningBehavior cloningBehavior = CloningBehavior.CLONE;
+  private CloningBehavior cloningBehavior = CloningBehavior.CLONE_ONCE;
 
   /** The side input values to provide to the {@link DoFn} under test. */
   private Map<PCollectionView<?>, Map<BoundedWindow, ?>> sideInputs =
@@ -732,22 +775,16 @@ public class DoFnTester<InputT, OutputT> {
   private Map<TupleTag<?>, List<WindowedValue<?>>> outputs;
 
   /** The state of processing of the {@link DoFn} under test. */
-  private State state;
+  private State state = State.UNINITIALIZED;
 
   private DoFnTester(OldDoFn<InputT, OutputT> origFn) {
     this.origFn = origFn;
-    resetState();
-  }
-
-  private void resetState() {
-    fn = null;
-    outputs = null;
-    accumulators = null;
-    state = State.UNSTARTED;
   }
 
   @SuppressWarnings("unchecked")
-  private void initializeState() {
+  private void initializeState() throws Exception {
+    checkState(state == State.UNINITIALIZED, "Already initialized");
+    checkState(fn == null, "Uninitialized but fn != null");
     if (cloningBehavior.equals(CloningBehavior.DO_NOT_CLONE)) {
       fn = origFn;
     } else {
@@ -756,6 +793,7 @@ public class DoFnTester<InputT, OutputT> {
               SerializableUtils.serializeToByteArray(origFn),
               origFn.toString());
     }
+    fn.setup();
     outputs = new HashMap<>();
     accumulators = new HashMap<>();
   }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/bef0e9de/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnTesterTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnTesterTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnTesterTest.java
index 3ed30fd..f208488 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnTesterTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnTesterTest.java
@@ -17,15 +17,17 @@
  */
 package org.apache.beam.sdk.transforms;
 
+import static com.google.common.base.Preconditions.checkState;
 import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasItems;
-import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertThat;
 import static org.junit.Assert.assertTrue;
 
 import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
 import org.apache.beam.sdk.coders.VarIntCoder;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
@@ -51,122 +53,180 @@ public class DoFnTesterTest {
 
   @Test
   public void processElement() throws Exception {
-    CounterDoFn counterDoFn = new CounterDoFn();
-    DoFnTester<Long, String> tester = DoFnTester.of(counterDoFn);
-
-    tester.processElement(1L);
-
-    List<String> take = tester.takeOutputElements();
+    for (DoFnTester.CloningBehavior cloning : DoFnTester.CloningBehavior.values()) {
+      try (DoFnTester<Long, String> tester = DoFnTester.of(new CounterDoFn())) {
+        tester.setCloningBehavior(cloning);
+        tester.processElement(1L);
 
-    assertThat(take, hasItems("1"));
+        List<String> take = tester.takeOutputElements();
 
-    // Following takeOutputElements(), neither takeOutputElements()
-    // nor peekOutputElements() return anything.
-    assertTrue(tester.takeOutputElements().isEmpty());
-    assertTrue(tester.peekOutputElements().isEmpty());
+        assertThat(take, hasItems("1"));
 
-    // processElement() caused startBundle() to be called, but finishBundle() was never called.
-    CounterDoFn deserializedDoFn = (CounterDoFn) tester.fn;
-    assertTrue(deserializedDoFn.wasStartBundleCalled());
-    assertFalse(deserializedDoFn.wasFinishBundleCalled());
+        // Following takeOutputElements(), neither takeOutputElements()
+        // nor peekOutputElements() return anything.
+        assertTrue(tester.takeOutputElements().isEmpty());
+        assertTrue(tester.peekOutputElements().isEmpty());
+      }
+    }
   }
 
   @Test
   public void processElementsWithPeeks() throws Exception {
-    CounterDoFn counterDoFn = new CounterDoFn();
-    DoFnTester<Long, String> tester = DoFnTester.of(counterDoFn);
-
-    // Explicitly call startBundle().
-    tester.startBundle();
-
-    // verify startBundle() was called but not finishBundle().
-    CounterDoFn deserializedDoFn = (CounterDoFn) tester.fn;
-    assertTrue(deserializedDoFn.wasStartBundleCalled());
-    assertFalse(deserializedDoFn.wasFinishBundleCalled());
-
-    // process a couple of elements.
-    tester.processElement(1L);
-    tester.processElement(2L);
-
-    // peek the first 2 outputs.
-    List<String> peek = tester.peekOutputElements();
-    assertThat(peek, hasItems("1", "2"));
-
-    // process a couple more.
-    tester.processElement(3L);
-    tester.processElement(4L);
-
-    // peek all the outputs so far.
-    peek = tester.peekOutputElements();
-    assertThat(peek, hasItems("1", "2", "3", "4"));
-    // take the outputs.
-    List<String> take = tester.takeOutputElements();
-    assertThat(take, hasItems("1", "2", "3", "4"));
-
-    // Following takeOutputElements(), neither takeOutputElements()
-    // nor peekOutputElements() return anything.
-    assertTrue(tester.peekOutputElements().isEmpty());
-    assertTrue(tester.takeOutputElements().isEmpty());
-
-    // verify finishBundle() hasn't been called yet.
-    assertTrue(deserializedDoFn.wasStartBundleCalled());
-    assertFalse(deserializedDoFn.wasFinishBundleCalled());
-
-    // process a couple more.
-    tester.processElement(5L);
-    tester.processElement(6L);
-
-    // peek and take now have only the 2 last outputs.
-    peek = tester.peekOutputElements();
-    assertThat(peek, hasItems("5", "6"));
-    take = tester.takeOutputElements();
-    assertThat(take, hasItems("5", "6"));
-
-    tester.finishBundle();
-
-    // verify finishBundle() was called.
-    assertTrue(deserializedDoFn.wasStartBundleCalled());
-    assertTrue(deserializedDoFn.wasFinishBundleCalled());
+    for (DoFnTester.CloningBehavior cloning : DoFnTester.CloningBehavior.values()) {
+      try (DoFnTester<Long, String> tester = DoFnTester.of(new CounterDoFn())) {
+        tester.setCloningBehavior(cloning);
+        // Explicitly call startBundle().
+        tester.startBundle();
+
+        // process a couple of elements.
+        tester.processElement(1L);
+        tester.processElement(2L);
+
+        // peek the first 2 outputs.
+        List<String> peek = tester.peekOutputElements();
+        assertThat(peek, hasItems("1", "2"));
+
+        // process a couple more.
+        tester.processElement(3L);
+        tester.processElement(4L);
+
+        // peek all the outputs so far.
+        peek = tester.peekOutputElements();
+        assertThat(peek, hasItems("1", "2", "3", "4"));
+        // take the outputs.
+        List<String> take = tester.takeOutputElements();
+        assertThat(take, hasItems("1", "2", "3", "4"));
+
+        // Following takeOutputElements(), neither takeOutputElements()
+        // nor peekOutputElements() return anything.
+        assertTrue(tester.peekOutputElements().isEmpty());
+        assertTrue(tester.takeOutputElements().isEmpty());
+
+        // process a couple more.
+        tester.processElement(5L);
+        tester.processElement(6L);
+
+        // peek and take now have only the 2 last outputs.
+        peek = tester.peekOutputElements();
+        assertThat(peek, hasItems("5", "6"));
+        take = tester.takeOutputElements();
+        assertThat(take, hasItems("5", "6"));
+
+        tester.finishBundle();
+      }
+    }
   }
 
   @Test
-  public void processElementAfterFinish() throws Exception {
-    DoFnTester<Long, String> tester = DoFnTester.of(new CounterDoFn());
-    tester.finishBundle();
+  public void processBundle() throws Exception {
+    for (DoFnTester.CloningBehavior cloning : DoFnTester.CloningBehavior.values()) {
+      try (DoFnTester<Long, String> tester = DoFnTester.of(new CounterDoFn())) {
+        tester.setCloningBehavior(cloning);
+        // processBundle() returns all the output like takeOutputElements().
+        assertThat(tester.processBundle(1L, 2L, 3L, 4L), hasItems("1", "2", "3", "4"));
+
+        // peek now returns nothing.
+        assertTrue(tester.peekOutputElements().isEmpty());
+      }
+    }
+  }
 
-    thrown.expect(IllegalStateException.class);
-    thrown.expectMessage("finishBundle() has already been called");
-    tester.processElement(1L);
+  @Test
+  public void processMultipleBundles() throws Exception {
+    for (DoFnTester.CloningBehavior cloning : DoFnTester.CloningBehavior.values()) {
+      try (DoFnTester<Long, String> tester = DoFnTester.of(new CounterDoFn())) {
+        tester.setCloningBehavior(cloning);
+        // processBundle() returns all the output like takeOutputElements().
+        assertThat(tester.processBundle(1L, 2L, 3L, 4L), hasItems("1", "2", "3", "4"));
+        assertThat(tester.processBundle(5L, 6L, 7L), hasItems("5", "6", "7"));
+        assertThat(tester.processBundle(8L, 9L), hasItems("8", "9"));
+
+        // peek now returns nothing.
+        assertTrue(tester.peekOutputElements().isEmpty());
+      }
+    }
   }
 
   @Test
-  public void processBatch() throws Exception {
-    CounterDoFn counterDoFn = new CounterDoFn();
-    DoFnTester<Long, String> tester = DoFnTester.of(counterDoFn);
+  public void doNotClone() throws Exception {
+    final AtomicInteger numSetupCalls = new AtomicInteger();
+    final AtomicInteger numTeardownCalls = new AtomicInteger();
+    DoFn<Long, String> fn =
+        new DoFn<Long, String>() {
+          @ProcessElement
+          public void process(ProcessContext context) {}
+
+          @Setup
+          public void setup() {
+            numSetupCalls.addAndGet(1);
+          }
+
+          @Teardown
+          public void teardown() {
+            numTeardownCalls.addAndGet(1);
+          }
+        };
+
+    try (DoFnTester<Long, String> tester = DoFnTester.of(fn)) {
+      tester.setCloningBehavior(DoFnTester.CloningBehavior.DO_NOT_CLONE);
+
+      tester.processBundle(1L, 2L, 3L);
+      tester.processBundle(4L, 5L);
+      tester.processBundle(6L);
+    }
+    assertEquals(1, numSetupCalls.get());
+    assertEquals(1, numTeardownCalls.get());
+  }
 
-    // processBundle() returns all the output like takeOutputElements().
-    List<String> take = tester.processBundle(1L, 2L, 3L, 4L);
+  private static class CountBundleCallsFn extends DoFn<Long, String> {
+    private int numStartBundleCalls = 0;
+    private int numFinishBundleCalls = 0;
 
-    assertThat(take, hasItems("1", "2", "3", "4"));
+    @ProcessElement
+    public void process(ProcessContext context) {
+      context.output(numStartBundleCalls + "/" + numFinishBundleCalls);
+    }
 
-    // peek now returns nothing.
-    assertTrue(tester.peekOutputElements().isEmpty());
+    @StartBundle
+    public void startBundle(Context context) {
+      ++numStartBundleCalls;
+    }
 
-    // verify startBundle() and finishBundle() were both called.
-    CounterDoFn deserializedDoFn = (CounterDoFn) tester.fn;
-    assertTrue(deserializedDoFn.wasStartBundleCalled());
-    assertTrue(deserializedDoFn.wasFinishBundleCalled());
+    @FinishBundle
+    public void finishBundle(Context context) {
+      ++numFinishBundleCalls;
+    }
   }
 
   @Test
-  public void processTimestampedElement() throws Exception {
-    DoFn<Long, TimestampedValue<Long>> reifyTimestamps = new ReifyTimestamps();
+  public void cloneOnce() throws Exception {
+    try (DoFnTester<Long, String> tester = DoFnTester.of(new CountBundleCallsFn())) {
+      tester.setCloningBehavior(DoFnTester.CloningBehavior.CLONE_ONCE);
+
+      assertThat(tester.processBundle(1L, 2L, 3L), contains("1/0", "1/0", "1/0"));
+      assertThat(tester.processBundle(4L, 5L), contains("2/1", "2/1"));
+      assertThat(tester.processBundle(6L), contains("3/2"));
+    }
+  }
+
+  @Test
+  public void clonePerBundle() throws Exception {
+    try (DoFnTester<Long, String> tester = DoFnTester.of(new CountBundleCallsFn())) {
+      tester.setCloningBehavior(DoFnTester.CloningBehavior.CLONE_PER_BUNDLE);
 
-    DoFnTester<Long, TimestampedValue<Long>> tester = DoFnTester.of(reifyTimestamps);
+      assertThat(tester.processBundle(1L, 2L, 3L), contains("1/0", "1/0", "1/0"));
+      assertThat(tester.processBundle(4L, 5L), contains("1/0", "1/0"));
+      assertThat(tester.processBundle(6L), contains("1/0"));
+    }
+  }
 
-    TimestampedValue<Long> input = TimestampedValue.of(1L, new Instant(100));
-    tester.processTimestampedElement(input);
-    assertThat(tester.takeOutputElements(), contains(input));
+  @Test
+  public void processTimestampedElement() throws Exception {
+    try (DoFnTester<Long, TimestampedValue<Long>> tester = DoFnTester.of(new ReifyTimestamps())) {
+      TimestampedValue<Long> input = TimestampedValue.of(1L, new Instant(100));
+      tester.processTimestampedElement(input);
+      assertThat(tester.takeOutputElements(), contains(input));
+    }
   }
 
   static class ReifyTimestamps extends DoFn<Long, TimestampedValue<Long>> {
@@ -178,86 +238,83 @@ public class DoFnTesterTest {
 
   @Test
   public void processElementWithOutputTimestamp() throws Exception {
-    CounterDoFn counterDoFn = new CounterDoFn();
-    DoFnTester<Long, String> tester = DoFnTester.of(counterDoFn);
-
-    tester.processElement(1L);
-    tester.processElement(2L);
-
-    List<TimestampedValue<String>> peek = tester.peekOutputElementsWithTimestamp();
-    TimestampedValue<String> one = TimestampedValue.of("1", new Instant(1000L));
-    TimestampedValue<String> two = TimestampedValue.of("2", new Instant(2000L));
-    assertThat(peek, hasItems(one, two));
-
-    tester.processElement(3L);
-    tester.processElement(4L);
-
-    TimestampedValue<String> three = TimestampedValue.of("3", new Instant(3000L));
-    TimestampedValue<String> four = TimestampedValue.of("4", new Instant(4000L));
-    peek = tester.peekOutputElementsWithTimestamp();
-    assertThat(peek, hasItems(one, two, three, four));
-    List<TimestampedValue<String>> take = tester.takeOutputElementsWithTimestamp();
-    assertThat(take, hasItems(one, two, three, four));
-
-    // Following takeOutputElementsWithTimestamp(), neither takeOutputElementsWithTimestamp()
-    // nor peekOutputElementsWithTimestamp() return anything.
-    assertTrue(tester.takeOutputElementsWithTimestamp().isEmpty());
-    assertTrue(tester.peekOutputElementsWithTimestamp().isEmpty());
-
-    // peekOutputElements() and takeOutputElements() also return nothing.
-    assertTrue(tester.peekOutputElements().isEmpty());
-    assertTrue(tester.takeOutputElements().isEmpty());
+    try (DoFnTester<Long, String> tester = DoFnTester.of(new CounterDoFn())) {
+      tester.processElement(1L);
+      tester.processElement(2L);
+
+      List<TimestampedValue<String>> peek = tester.peekOutputElementsWithTimestamp();
+      TimestampedValue<String> one = TimestampedValue.of("1", new Instant(1000L));
+      TimestampedValue<String> two = TimestampedValue.of("2", new Instant(2000L));
+      assertThat(peek, hasItems(one, two));
+
+      tester.processElement(3L);
+      tester.processElement(4L);
+
+      TimestampedValue<String> three = TimestampedValue.of("3", new Instant(3000L));
+      TimestampedValue<String> four = TimestampedValue.of("4", new Instant(4000L));
+      peek = tester.peekOutputElementsWithTimestamp();
+      assertThat(peek, hasItems(one, two, three, four));
+      List<TimestampedValue<String>> take = tester.takeOutputElementsWithTimestamp();
+      assertThat(take, hasItems(one, two, three, four));
+
+      // Following takeOutputElementsWithTimestamp(), neither takeOutputElementsWithTimestamp()
+      // nor peekOutputElementsWithTimestamp() return anything.
+      assertTrue(tester.takeOutputElementsWithTimestamp().isEmpty());
+      assertTrue(tester.peekOutputElementsWithTimestamp().isEmpty());
+
+      // peekOutputElements() and takeOutputElements() also return nothing.
+      assertTrue(tester.peekOutputElements().isEmpty());
+      assertTrue(tester.takeOutputElements().isEmpty());
+    }
   }
 
   @Test
   public void getAggregatorValuesShouldGetValueOfCounter() throws Exception {
     CounterDoFn counterDoFn = new CounterDoFn();
-    DoFnTester<Long, String> tester = DoFnTester.of(counterDoFn);
-    tester.processBundle(1L, 2L, 4L, 8L);
-
-    Long aggregatorVal = tester.getAggregatorValue(counterDoFn.agg);
-
-    assertThat(aggregatorVal, equalTo(15L));
+    try (DoFnTester<Long, String> tester = DoFnTester.of(counterDoFn)) {
+      tester.processBundle(1L, 2L, 4L, 8L);
+      assertThat(tester.getAggregatorValue(counterDoFn.agg), equalTo(15L));
+    }
   }
 
   @Test
   public void getAggregatorValuesWithEmptyCounterShouldSucceed() throws Exception {
     CounterDoFn counterDoFn = new CounterDoFn();
-    DoFnTester<Long, String> tester = DoFnTester.of(counterDoFn);
-    tester.processBundle();
-    Long aggregatorVal = tester.getAggregatorValue(counterDoFn.agg);
-    // empty bundle
-    assertThat(aggregatorVal, equalTo(0L));
+    try (DoFnTester<Long, String> tester = DoFnTester.of(counterDoFn)) {
+      tester.processBundle();
+      // empty bundle
+      assertThat(tester.getAggregatorValue(counterDoFn.agg), equalTo(0L));
+    }
   }
 
   @Test
   public void getAggregatorValuesInStartFinishBundleShouldGetValues() throws Exception {
-    CounterDoFn fn = new CounterDoFn(1L, 2L);
-    DoFnTester<Long, String> tester = DoFnTester.of(fn);
-    tester.processBundle(0L, 0L);
+    CounterDoFn fn = new CounterDoFn();
+    try (DoFnTester<Long, String> tester = DoFnTester.of(fn)) {
+      tester.processBundle(1L, 2L, 3L, 4L);
 
-    Long aggValue = tester.getAggregatorValue(fn.agg);
-    assertThat(aggValue, equalTo(1L + 2L));
+      assertThat(tester.getAggregatorValue(fn.startBundleCalls), equalTo(1L));
+      assertThat(tester.getAggregatorValue(fn.finishBundleCalls), equalTo(1L));
+    }
   }
 
   @Test
   public void peekValuesInWindow() throws Exception {
-    CounterDoFn fn = new CounterDoFn(1L, 2L);
-    DoFnTester<Long, String> tester = DoFnTester.of(fn);
-
-    tester.startBundle();
-    tester.processElement(1L);
-    tester.processElement(2L);
-    tester.finishBundle();
-
-    assertThat(
-        tester.peekOutputElementsInWindow(GlobalWindow.INSTANCE),
-        containsInAnyOrder(
-            TimestampedValue.of("1", new Instant(1000L)),
-            TimestampedValue.of("2", new Instant(2000L))));
-    assertThat(
-        tester.peekOutputElementsInWindow(new IntervalWindow(new Instant(0L), new Instant(10L))),
-        Matchers.<TimestampedValue<String>>emptyIterable());
+    try (DoFnTester<Long, String> tester = DoFnTester.of(new CounterDoFn())) {
+      tester.startBundle();
+      tester.processElement(1L);
+      tester.processElement(2L);
+      tester.finishBundle();
+
+      assertThat(
+          tester.peekOutputElementsInWindow(GlobalWindow.INSTANCE),
+          containsInAnyOrder(
+              TimestampedValue.of("1", new Instant(1000L)),
+              TimestampedValue.of("2", new Instant(2000L))));
+      assertThat(
+          tester.peekOutputElementsInWindow(new IntervalWindow(new Instant(0L), new Instant(10L))),
+          Matchers.<TimestampedValue<String>>emptyIterable());
+    }
   }
 
   @Test
@@ -265,15 +322,14 @@ public class DoFnTesterTest {
     final PCollectionView<Integer> value =
         PCollectionViews.singletonView(
             TestPipeline.create(), WindowingStrategy.globalDefault(), true, 0, VarIntCoder.of());
-    OldDoFn<Integer, Integer> fn = new SideInputDoFn(value);
-
-    DoFnTester<Integer, Integer> tester = DoFnTester.of(fn);
 
-    tester.processElement(1);
-    tester.processElement(2);
-    tester.processElement(4);
-    tester.processElement(8);
-    assertThat(tester.peekOutputElements(), containsInAnyOrder(0, 0, 0, 0));
+    try (DoFnTester<Integer, Integer> tester = DoFnTester.of(new SideInputDoFn(value))) {
+      tester.processElement(1);
+      tester.processElement(2);
+      tester.processElement(4);
+      tester.processElement(8);
+      assertThat(tester.peekOutputElements(), containsInAnyOrder(0, 0, 0, 0));
+    }
   }
 
   @Test
@@ -281,17 +337,17 @@ public class DoFnTesterTest {
     final PCollectionView<Integer> value =
         PCollectionViews.singletonView(
             TestPipeline.create(), WindowingStrategy.globalDefault(), true, 0, VarIntCoder.of());
-    OldDoFn<Integer, Integer> fn = new SideInputDoFn(value);
 
-    DoFnTester<Integer, Integer> tester = DoFnTester.of(fn);
-    tester.setSideInput(value, GlobalWindow.INSTANCE, -2);
-    tester.processElement(16);
-    tester.processElement(32);
-    tester.processElement(64);
-    tester.processElement(128);
-    tester.finishBundle();
+    try (DoFnTester<Integer, Integer> tester = DoFnTester.of(new SideInputDoFn(value))) {
+      tester.setSideInput(value, GlobalWindow.INSTANCE, -2);
+      tester.processElement(16);
+      tester.processElement(32);
+      tester.processElement(64);
+      tester.processElement(128);
+      tester.finishBundle();
 
-    assertThat(tester.peekOutputElements(), containsInAnyOrder(-2, -2, -2, -2));
+      assertThat(tester.peekOutputElements(), containsInAnyOrder(-2, -2, -2, -2));
+    }
   }
 
   private static class SideInputDoFn extends OldDoFn<Integer, Integer> {
@@ -308,50 +364,56 @@ public class DoFnTesterTest {
   }
 
   /**
-   * An {@link OldDoFn} that adds values to an aggregator and converts input to String in
+   * A {@link DoFn} that adds values to an aggregator and converts input to String in
    * {@link OldDoFn#processElement).
    */
-  private static class CounterDoFn extends OldDoFn<Long, String> {
+  private static class CounterDoFn extends DoFn<Long, String> {
     Aggregator<Long, Long> agg = createAggregator("ctr", new Sum.SumLongFn());
-    private final long startBundleVal;
-    private final long finishBundleVal;
-    private boolean startBundleCalled;
-    private boolean finishBundleCalled;
-
-    public CounterDoFn() {
-      this(0L, 0L);
+    Aggregator<Long, Long> startBundleCalls =
+        createAggregator("startBundleCalls", new Sum.SumLongFn());
+    Aggregator<Long, Long> finishBundleCalls =
+        createAggregator("finishBundleCalls", new Sum.SumLongFn());
+
+    private enum LifecycleState {
+      UNINITIALIZED,
+      SET_UP,
+      INSIDE_BUNDLE,
+      TORN_DOWN
     }
+    private LifecycleState state = LifecycleState.UNINITIALIZED;
 
-    public CounterDoFn(long start, long finish) {
-      this.startBundleVal = start;
-      this.finishBundleVal = finish;
+    @Setup
+    public void setup() {
+      checkState(state == LifecycleState.UNINITIALIZED, "Wrong state: %s", state);
+      state = LifecycleState.SET_UP;
     }
 
-    @Override
+    @StartBundle
     public void startBundle(Context c) {
-      agg.addValue(startBundleVal);
-      startBundleCalled = true;
+      checkState(state == LifecycleState.SET_UP, "Wrong state: %s", state);
+      state = LifecycleState.INSIDE_BUNDLE;
+      startBundleCalls.addValue(1L);
     }
 
-    @Override
+    @ProcessElement
     public void processElement(ProcessContext c) throws Exception {
+      checkState(state == LifecycleState.INSIDE_BUNDLE, "Wrong state: %s", state);
       agg.addValue(c.element());
       Instant instant = new Instant(1000L * c.element());
       c.outputWithTimestamp(c.element().toString(), instant);
     }
 
-    @Override
+    @FinishBundle
     public void finishBundle(Context c) {
-      agg.addValue(finishBundleVal);
-      finishBundleCalled = true;
-    }
-
-    boolean wasStartBundleCalled() {
-      return startBundleCalled;
+      checkState(state == LifecycleState.INSIDE_BUNDLE, "Wrong state: %s", state);
+      state = LifecycleState.SET_UP;
+      finishBundleCalls.addValue(1L);
     }
 
-    boolean wasFinishBundleCalled() {
-      return finishBundleCalled;
+    @Teardown
+    public void teardown() {
+      checkState(state == LifecycleState.SET_UP, "Wrong state: %s", state);
+      state = LifecycleState.TORN_DOWN;
     }
   }
 }


Mime
View raw message