beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From dhalp...@apache.org
Subject [2/5] incubator-beam git commit: Add DoFn @Setup and @Teardown
Date Mon, 15 Aug 2016 21:17:07 GMT
Add DoFn @Setup and @Teardown

Methods annotated with @Setup are used to perform expensive setup work on
a DoFn; methods annotated with @Teardown clean up a DoFn after another
method throws an exception or the DoFn is discarded.


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/12abb1b0
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/12abb1b0
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/12abb1b0

Branch: refs/heads/master
Commit: 12abb1b02246b8d36021c7b1a970daf1b64ba4b9
Parents: cf0bf3b
Author: Thomas Groh <tgroh@google.com>
Authored: Thu Jul 14 14:51:02 2016 -0700
Committer: Dan Halperin <dhalperi@google.com>
Committed: Mon Aug 15 14:16:54 2016 -0700

----------------------------------------------------------------------
 .../runners/direct/DoFnLifecycleManager.java    |  38 +-
 ...ecycleManagerRemovingTransformEvaluator.java |  39 +-
 .../runners/direct/DoFnLifecycleManagers.java   |  45 ++
 .../direct/ParDoMultiEvaluatorFactory.java      |   4 +-
 .../direct/ParDoSingleEvaluatorFactory.java     |   4 +-
 .../direct/DoFnLifecycleManagerTest.java        |  49 +++
 .../direct/DoFnLifecycleManagersTest.java       | 142 +++++++
 .../functions/FlinkDoFnFunction.java            |  12 +-
 .../functions/FlinkMultiOutputDoFnFunction.java |  31 +-
 .../streaming/FlinkAbstractParDoWrapper.java    |   2 +
 .../FlinkGroupAlsoByWindowWrapper.java          |   2 +
 .../runners/spark/translation/DoFnFunction.java |  23 +-
 .../spark/translation/MultiDoFnFunction.java    |   1 +
 .../spark/translation/SparkProcessContext.java  |  17 +
 .../org/apache/beam/sdk/transforms/DoFn.java    |  31 +-
 .../beam/sdk/transforms/DoFnReflector.java      |  70 +++-
 .../org/apache/beam/sdk/transforms/OldDoFn.java |  25 ++
 .../org/apache/beam/sdk/transforms/ParDo.java   |  15 +-
 .../beam/sdk/transforms/DoFnReflectorTest.java  |  65 +++
 .../apache/beam/sdk/transforms/ParDoTest.java   | 420 ++++++++++++++++++-
 20 files changed, 970 insertions(+), 65 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManager.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManager.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManager.java
index 2783657..3f4f2c6 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManager.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManager.java
@@ -18,6 +18,7 @@
 
 package org.apache.beam.runners.direct;
 
+import org.apache.beam.sdk.runners.PipelineRunner;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.OldDoFn;
 import org.apache.beam.sdk.util.SerializableUtils;
@@ -26,6 +27,13 @@ import com.google.common.cache.CacheBuilder;
 import com.google.common.cache.CacheLoader;
 import com.google.common.cache.LoadingCache;
 
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Iterator;
+
 /**
  * Manages {@link DoFn} setup, teardown, and serialization.
  *
@@ -35,6 +43,8 @@ import com.google.common.cache.LoadingCache;
  * {@link DoFn DoFns}.
  */
 class DoFnLifecycleManager {
+  private static final Logger LOG = LoggerFactory.getLogger(DoFnLifecycleManager.class);
+
   public static DoFnLifecycleManager of(OldDoFn<?, ?> original) {
     return new DoFnLifecycleManager(original);
   }
@@ -52,14 +62,30 @@ class DoFnLifecycleManager {
 
   public void remove() throws Exception {
     Thread currentThread = Thread.currentThread();
-    outstanding.invalidate(currentThread);
+    OldDoFn<?, ?> fn = outstanding.asMap().remove(currentThread);
+    fn.teardown();
   }
 
   /**
-   * Remove all {@link DoFn DoFns} from this {@link DoFnLifecycleManager}.
+   * Remove all {@link DoFn DoFns} from this {@link DoFnLifecycleManager}. Returns all exceptions
+   * that were thrown while calling the remove methods.
+   *
+   * <p>If the returned Collection is nonempty, an exception was thrown from at least one
+   * {@link DoFn#teardown()} method, and the {@link PipelineRunner} should throw an exception.
    */
-  public void removeAll() throws Exception {
-    outstanding.invalidateAll();
+  public Collection<Exception> removeAll() throws Exception {
+    Iterator<OldDoFn<?, ?>> fns = outstanding.asMap().values().iterator();
+    Collection<Exception> thrown = new ArrayList<>();
+    while (fns.hasNext()) {
+      OldDoFn<?, ?> fn = fns.next();
+      fns.remove();
+      try {
+        fn.teardown();
+      } catch (Exception e) {
+        thrown.add(e);
+      }
+    }
+    return thrown;
   }
 
   private class DeserializingCacheLoader extends CacheLoader<Thread, OldDoFn<?, ?>> {
@@ -71,8 +97,10 @@ class DoFnLifecycleManager {
 
     @Override
     public OldDoFn<?, ?> load(Thread key) throws Exception {
-      return (OldDoFn<?, ?>) SerializableUtils.deserializeFromByteArray(original,
+      OldDoFn<?, ?> fn = (OldDoFn<?, ?>) SerializableUtils.deserializeFromByteArray(original,
           "DoFn Copy in thread " + key.getName());
+      fn.setup();
+      return fn;
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManagerRemovingTransformEvaluator.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManagerRemovingTransformEvaluator.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManagerRemovingTransformEvaluator.java
index f3d1d4f..523273c 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManagerRemovingTransformEvaluator.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManagerRemovingTransformEvaluator.java
@@ -34,14 +34,14 @@ class DoFnLifecycleManagerRemovingTransformEvaluator<InputT> implements Transfor
   private final DoFnLifecycleManager lifecycleManager;
 
   public static <InputT> TransformEvaluator<InputT> wrapping(
-      TransformEvaluator<InputT> underlying, DoFnLifecycleManager threadLocal) {
-    return new DoFnLifecycleManagerRemovingTransformEvaluator<>(underlying, threadLocal);
+      TransformEvaluator<InputT> underlying, DoFnLifecycleManager lifecycleManager) {
+    return new DoFnLifecycleManagerRemovingTransformEvaluator<>(underlying, lifecycleManager);
   }
 
   private DoFnLifecycleManagerRemovingTransformEvaluator(
-      TransformEvaluator<InputT> underlying, DoFnLifecycleManager threadLocal) {
+      TransformEvaluator<InputT> underlying, DoFnLifecycleManager lifecycleManager) {
     this.underlying = underlying;
-    this.lifecycleManager = threadLocal;
+    this.lifecycleManager = lifecycleManager;
   }
 
   @Override
@@ -49,14 +49,7 @@ class DoFnLifecycleManagerRemovingTransformEvaluator<InputT> implements Transfor
     try {
       underlying.processElement(element);
     } catch (Exception e) {
-      try {
-        lifecycleManager.remove();
-      } catch (Exception removalException) {
-        LOG.error(
-            "Exception encountered while cleaning up after processing an element",
-            removalException);
-        e.addSuppressed(removalException);
-      }
+      onException(e, "Exception encountered while cleaning up after processing an element");
       throw e;
     }
   }
@@ -66,15 +59,21 @@ class DoFnLifecycleManagerRemovingTransformEvaluator<InputT> implements Transfor
     try {
       return underlying.finishBundle();
     } catch (Exception e) {
-      try {
-        lifecycleManager.remove();
-      } catch (Exception removalException) {
-        LOG.error(
-            "Exception encountered while cleaning up after finishing a bundle",
-            removalException);
-        e.addSuppressed(removalException);
-      }
+      onException(e, "Exception encountered while cleaning up after finishing a bundle");
       throw e;
     }
   }
+
+  private void onException(Exception e, String msg) {
+    try {
+      lifecycleManager.remove();
+    } catch (Exception removalException) {
+      if (removalException instanceof InterruptedException) {
+        Thread.currentThread().interrupt();
+      }
+      LOG.error(msg, removalException);
+      e.addSuppressed(removalException);
+    }
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManagers.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManagers.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManagers.java
new file mode 100644
index 0000000..6a1dd8f
--- /dev/null
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DoFnLifecycleManagers.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.direct;
+
+import java.util.ArrayList;
+import java.util.Collection;
+
+/**
+ * Utility methods for interacting with {@link DoFnLifecycleManager DoFnLifecycleManagers}.
+ */
+class DoFnLifecycleManagers {
+  private DoFnLifecycleManagers() {
+    /* Do not instantiate */
+  }
+
+  static void removeAllFromManagers(Iterable<DoFnLifecycleManager> managers) throws Exception {
+    Collection<Exception> thrown = new ArrayList<>();
+    for (DoFnLifecycleManager manager : managers) {
+      thrown.addAll(manager.removeAll());
+    }
+    if (!thrown.isEmpty()) {
+      Exception overallException = new Exception("Exceptions thrown while tearing down DoFns");
+      for (Exception e : thrown) {
+        overallException.addSuppressed(e);
+      }
+      throw overallException;
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiEvaluatorFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiEvaluatorFactory.java
index f2455e1..2d05e68 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiEvaluatorFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiEvaluatorFactory.java
@@ -69,9 +69,7 @@ class ParDoMultiEvaluatorFactory implements TransformEvaluatorFactory {
 
   @Override
   public void cleanup() throws Exception {
-    for (DoFnLifecycleManager lifecycleManager : fnClones.asMap().values()) {
-      lifecycleManager.removeAll();
-    }
+    DoFnLifecycleManagers.removeAllFromManagers(fnClones.asMap().values());
   }
 
   private <InT, OuT> TransformEvaluator<InT> createMultiEvaluator(

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleEvaluatorFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleEvaluatorFactory.java
index a0fbd1d..97cbfa7 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleEvaluatorFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleEvaluatorFactory.java
@@ -70,9 +70,7 @@ class ParDoSingleEvaluatorFactory implements TransformEvaluatorFactory {
 
   @Override
   public void cleanup() throws Exception {
-    for (DoFnLifecycleManager lifecycleManager : fnClones.asMap().values()) {
-      lifecycleManager.removeAll();
-    }
+    DoFnLifecycleManagers.removeAllFromManagers(fnClones.asMap().values());
   }
 
   private <InputT, OutputT> TransformEvaluator<InputT> createSingleEvaluator(

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DoFnLifecycleManagerTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DoFnLifecycleManagerTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DoFnLifecycleManagerTest.java
index f316e19..77b3296 100644
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DoFnLifecycleManagerTest.java
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DoFnLifecycleManagerTest.java
@@ -18,7 +18,9 @@
 
 package org.apache.beam.runners.direct;
 
+import static com.google.common.base.Preconditions.checkState;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.not;
 import static org.hamcrest.Matchers.theInstance;
 import static org.junit.Assert.assertThat;
@@ -49,6 +51,8 @@ public class DoFnLifecycleManagerTest {
     TestFn obtained = (TestFn) mgr.get();
 
     assertThat(obtained, not(theInstance(fn)));
+    assertThat(obtained.setupCalled, is(true));
+    assertThat(obtained.teardownCalled, is(false));
   }
 
   @Test
@@ -57,6 +61,8 @@ public class DoFnLifecycleManagerTest {
     TestFn secondObtained = (TestFn) mgr.get();
 
     assertThat(obtained, theInstance(secondObtained));
+    assertThat(obtained.setupCalled, is(true));
+    assertThat(obtained.teardownCalled, is(false));
   }
 
   @Test
@@ -74,6 +80,7 @@ public class DoFnLifecycleManagerTest {
     }
 
     for (TestFn fn : fns) {
+      assertThat(fn.setupCalled, is(true));
       int sameInstances = 0;
       for (TestFn otherFn : fns) {
         if (otherFn == fn) {
@@ -90,10 +97,33 @@ public class DoFnLifecycleManagerTest {
     mgr.remove();
 
     assertThat(obtained, not(theInstance(fn)));
+    assertThat(obtained.setupCalled, is(true));
+    assertThat(obtained.teardownCalled, is(true));
 
     assertThat(mgr.get(), not(Matchers.<OldDoFn<?, ?>>theInstance(obtained)));
   }
 
+  @Test
+  public void teardownAllOnRemoveAll() throws Exception {
+    CountDownLatch startSignal = new CountDownLatch(1);
+    ExecutorService executor = Executors.newCachedThreadPool();
+    List<Future<TestFn>> futures = new ArrayList<>();
+    for (int i = 0; i < 10; i++) {
+      futures.add(executor.submit(new GetFnCallable(mgr, startSignal)));
+    }
+    startSignal.countDown();
+    List<TestFn> fns = new ArrayList<>();
+    for (Future<TestFn> future : futures) {
+      fns.add(future.get(1L, TimeUnit.SECONDS));
+    }
+    mgr.removeAll();
+
+    for (TestFn fn : fns) {
+      assertThat(fn.setupCalled, is(true));
+      assertThat(fn.teardownCalled, is(true));
+    }
+  }
+
   private static class GetFnCallable implements Callable<TestFn> {
     private final DoFnLifecycleManager mgr;
     private final CountDownLatch startSignal;
@@ -112,8 +142,27 @@ public class DoFnLifecycleManagerTest {
 
 
   private static class TestFn extends OldDoFn<Object, Object> {
+    boolean setupCalled = false;
+    boolean teardownCalled = false;
+
+    @Override
+    public void setup() {
+      checkState(!setupCalled);
+      checkState(!teardownCalled);
+
+      setupCalled = true;
+    }
+
     @Override
     public void processElement(ProcessContext c) throws Exception {
     }
+
+    @Override
+    public void teardown() {
+      checkState(setupCalled);
+      checkState(!teardownCalled);
+
+      teardownCalled = true;
+    }
   }
 }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DoFnLifecycleManagersTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DoFnLifecycleManagersTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DoFnLifecycleManagersTest.java
new file mode 100644
index 0000000..8be3d52
--- /dev/null
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DoFnLifecycleManagersTest.java
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.direct;
+
+import static org.hamcrest.Matchers.equalTo;
+
+import org.apache.beam.sdk.transforms.OldDoFn;
+
+import com.google.common.collect.ImmutableList;
+
+import org.hamcrest.BaseMatcher;
+import org.hamcrest.Description;
+import org.hamcrest.Matcher;
+import org.hamcrest.Matchers;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import java.util.ArrayList;
+import java.util.Collection;
+
+/**
+ * Tests for {@link DoFnLifecycleManagers}.
+ */
+@RunWith(JUnit4.class)
+public class DoFnLifecycleManagersTest {
+  @Rule public ExpectedException thrown = ExpectedException.none();
+
+  @Test
+  public void removeAllWhenManagersThrowSuppressesAndThrows() throws Exception {
+    DoFnLifecycleManager first = DoFnLifecycleManager.of(new ThrowsInCleanupFn("foo"));
+    DoFnLifecycleManager second = DoFnLifecycleManager.of(new ThrowsInCleanupFn("bar"));
+    DoFnLifecycleManager third = DoFnLifecycleManager.of(new ThrowsInCleanupFn("baz"));
+    first.get();
+    second.get();
+    third.get();
+
+    final Collection<Matcher<? super Throwable>> suppressions = new ArrayList<>();
+    suppressions.add(new ThrowableMessageMatcher("foo"));
+    suppressions.add(new ThrowableMessageMatcher("bar"));
+    suppressions.add(new ThrowableMessageMatcher("baz"));
+
+    thrown.expect(
+        new BaseMatcher<Exception>() {
+          @Override
+          public void describeTo(Description description) {
+            description
+                .appendText("Exception suppressing ")
+                .appendList("[", ", ", "]", suppressions);
+          }
+
+          @Override
+          public boolean matches(Object item) {
+            if (!(item instanceof Exception)) {
+              return false;
+            }
+            Exception that = (Exception) item;
+            return Matchers.containsInAnyOrder(suppressions)
+                .matches(ImmutableList.copyOf(that.getSuppressed()));
+          }
+        });
+
+    DoFnLifecycleManagers.removeAllFromManagers(ImmutableList.of(first, second, third));
+  }
+
+  @Test
+  public void whenManagersSucceedSucceeds() throws Exception {
+    DoFnLifecycleManager first = DoFnLifecycleManager.of(new EmptyFn());
+    DoFnLifecycleManager second = DoFnLifecycleManager.of(new EmptyFn());
+    DoFnLifecycleManager third = DoFnLifecycleManager.of(new EmptyFn());
+    first.get();
+    second.get();
+    third.get();
+
+    DoFnLifecycleManagers.removeAllFromManagers(ImmutableList.of(first, second, third));
+  }
+
+  private static class ThrowsInCleanupFn extends OldDoFn<Object, Object> {
+    private final String message;
+
+    private ThrowsInCleanupFn(String message) {
+      this.message = message;
+    }
+
+    @Override
+    public void processElement(ProcessContext c) throws Exception {
+    }
+
+    @Override
+    public void teardown() throws Exception {
+      throw new Exception(message);
+    }
+  }
+
+
+  private static class ThrowableMessageMatcher extends BaseMatcher<Throwable> {
+    private final Matcher<String> messageMatcher;
+
+    public ThrowableMessageMatcher(String message) {
+      this.messageMatcher = equalTo(message);
+    }
+
+    @Override
+    public boolean matches(Object item) {
+      if (!(item instanceof Throwable)) {
+        return false;
+      }
+      Throwable that = (Throwable) item;
+      return messageMatcher.matches(that.getMessage());
+    }
+
+    @Override
+    public void describeTo(Description description) {
+      description.appendText("a throwable with a message ").appendDescriptionOf(messageMatcher);
+    }
+  }
+
+
+  private static class EmptyFn extends OldDoFn<Object, Object> {
+    @Override
+    public void processElement(ProcessContext c) throws Exception {
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java
----------------------------------------------------------------------
diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java
index a4af1b0..fdf1e59 100644
--- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java
+++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java
@@ -25,6 +25,7 @@ import org.apache.beam.sdk.util.WindowingStrategy;
 import org.apache.beam.sdk.values.PCollectionView;
 
 import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.configuration.Configuration;
 import org.apache.flink.util.Collector;
 
 import java.util.Map;
@@ -86,7 +87,7 @@ public class FlinkDoFnFunction<InputT, OutputT>
       // side inputs and window access also only works if an element
       // is in only one window
       for (WindowedValue<InputT> value : values) {
-        for (WindowedValue<InputT> explodedValue: value.explodeWindows()) {
+        for (WindowedValue<InputT> explodedValue : value.explodeWindows()) {
           context = context.forWindowedValue(value);
           doFn.processElement(context);
         }
@@ -99,4 +100,13 @@ public class FlinkDoFnFunction<InputT, OutputT>
     this.doFn.finishBundle(context);
   }
 
+  @Override
+  public void open(Configuration parameters) throws Exception {
+    doFn.setup();
+  }
+
+  @Override
+  public void close() throws Exception {
+    doFn.teardown();
+  }
 }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputDoFnFunction.java
----------------------------------------------------------------------
diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputDoFnFunction.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputDoFnFunction.java
index 6e673fc..5013b90 100644
--- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputDoFnFunction.java
+++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputDoFnFunction.java
@@ -27,6 +27,7 @@ import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.TupleTag;
 
 import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.configuration.Configuration;
 import org.apache.flink.util.Collector;
 
 import java.util.Map;
@@ -75,14 +76,15 @@ public class FlinkMultiOutputDoFnFunction<InputT, OutputT>
       Iterable<WindowedValue<InputT>> values,
       Collector<WindowedValue<RawUnionValue>> out) throws Exception {
 
-    FlinkProcessContext<InputT, OutputT> context = new FlinkMultiOutputProcessContext<>(
-        serializedOptions.getPipelineOptions(),
-        getRuntimeContext(),
-        doFn,
-        windowingStrategy,
-        out,
-        outputMap,
-        sideInputs);
+    FlinkProcessContext<InputT, OutputT> context =
+        new FlinkMultiOutputProcessContext<>(
+            serializedOptions.getPipelineOptions(),
+            getRuntimeContext(),
+            doFn,
+            windowingStrategy,
+            out,
+            outputMap,
+            sideInputs);
 
     this.doFn.startBundle(context);
 
@@ -97,14 +99,23 @@ public class FlinkMultiOutputDoFnFunction<InputT, OutputT>
       // side inputs and window access also only works if an element
       // is in only one window
       for (WindowedValue<InputT> value : values) {
-        for (WindowedValue<InputT> explodedValue: value.explodeWindows()) {
+        for (WindowedValue<InputT> explodedValue : value.explodeWindows()) {
           context = context.forWindowedValue(value);
           doFn.processElement(context);
         }
       }
     }
 
-
     this.doFn.finishBundle(context);
   }
+
+  @Override
+  public void open(Configuration parameters) throws Exception {
+    doFn.setup();
+  }
+
+  @Override
+  public void close() throws Exception {
+    doFn.teardown();
+  }
 }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkAbstractParDoWrapper.java
----------------------------------------------------------------------
diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkAbstractParDoWrapper.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkAbstractParDoWrapper.java
index 74ec66a..a9dd865 100644
--- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkAbstractParDoWrapper.java
+++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkAbstractParDoWrapper.java
@@ -70,6 +70,7 @@ public abstract class FlinkAbstractParDoWrapper<IN, OUTDF, OUTFL> extends RichFl
 
   @Override
   public void open(Configuration parameters) throws Exception {
+    doFn.setup();
   }
 
   @Override
@@ -78,6 +79,7 @@ public abstract class FlinkAbstractParDoWrapper<IN, OUTDF, OUTFL> extends RichFl
       // we have initialized the context
       this.doFn.finishBundle(this.context);
     }
+    this.doFn.teardown();
   }
 
   @Override

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkGroupAlsoByWindowWrapper.java
----------------------------------------------------------------------
diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkGroupAlsoByWindowWrapper.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkGroupAlsoByWindowWrapper.java
index 103a12b..4fddb53 100644
--- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkGroupAlsoByWindowWrapper.java
+++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkGroupAlsoByWindowWrapper.java
@@ -252,6 +252,7 @@ public class FlinkGroupAlsoByWindowWrapper<K, VIN, VACC, VOUT>
   @Override
   public void open() throws Exception {
     super.open();
+    operator.setup();
     this.context = new ProcessContext(operator, new TimestampedCollector<>(output), this.timerInternals);
     operator.startBundle(context);
   }
@@ -351,6 +352,7 @@ public class FlinkGroupAlsoByWindowWrapper<K, VIN, VACC, VOUT>
   @Override
   public void close() throws Exception {
     operator.finishBundle(context);
+    operator.teardown();
     super.close();
   }
 

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java
index f4ce516..c08d185 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java
@@ -24,6 +24,8 @@ import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.TupleTag;
 
 import org.apache.spark.api.java.function.FlatMapFunction;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import java.util.Iterator;
 import java.util.LinkedList;
@@ -40,6 +42,8 @@ public class DoFnFunction<InputT, OutputT>
     implements FlatMapFunction<Iterator<WindowedValue<InputT>>,
     WindowedValue<OutputT>> {
   private final OldDoFn<InputT, OutputT> mFunction;
+  private static final Logger LOG = LoggerFactory.getLogger(DoFnFunction.class);
+
   private final SparkRuntimeContext mRuntimeContext;
   private final Map<TupleTag<?>, BroadcastHelper<?>> mSideInputs;
 
@@ -61,8 +65,23 @@ public class DoFnFunction<InputT, OutputT>
       Exception {
     ProcCtxt ctxt = new ProcCtxt(mFunction, mRuntimeContext, mSideInputs);
     ctxt.setup();
-    mFunction.startBundle(ctxt);
-    return ctxt.getOutputIterable(iter, mFunction);
+    try {
+      mFunction.setup();
+      mFunction.startBundle(ctxt);
+      return ctxt.getOutputIterable(iter, mFunction);
+    } catch (Exception e) {
+      try {
+        // this teardown handles exceptions encountered in setup() and startBundle(). teardown
+        // after execution or due to exceptions in process element is called in the iterator
+        // produced by ctxt.getOutputIterable returned from this method.
+        mFunction.teardown();
+      } catch (Exception teardownException) {
+        LOG.error(
+            "Suppressing exception while tearing down Function {}", mFunction, teardownException);
+        e.addSuppressed(teardownException);
+      }
+      throw e;
+    }
   }
 
   private class ProcCtxt extends SparkProcessContext<InputT, OutputT, WindowedValue<OutputT>> {

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
index e33578d..abf0e83 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
@@ -65,7 +65,19 @@ class MultiDoFnFunction<InputT, OutputT>
   public Iterable<Tuple2<TupleTag<?>, WindowedValue<?>>>
       call(Iterator<WindowedValue<InputT>> iter) throws Exception {
     ProcCtxt ctxt = new ProcCtxt(mFunction, mRuntimeContext, mSideInputs);
-    mFunction.startBundle(ctxt);
-    ctxt.setup();
-    return ctxt.getOutputIterable(iter, mFunction);
+    try {
+      mFunction.setup();
+      mFunction.startBundle(ctxt);
+      ctxt.setup();
+      return ctxt.getOutputIterable(iter, mFunction);
+    } catch (Exception e) {
+      try {
+        // As in DoFnFunction: tear down after a failure before iteration starts. Failures in
+        // processElement/finishBundle are torn down by the iterator from getOutputIterable.
+        mFunction.teardown();
+      } catch (Exception teardownException) {
+        e.addSuppressed(teardownException);
+      }
+      throw e;
+    }
   }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java
index 2f06a1c..1cdbd92 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java
@@ -238,6 +238,7 @@ public abstract class SparkProcessContext<InputT, OutputT, ValueT>
           try {
             doFn.processElement(SparkProcessContext.this);
           } catch (Exception e) {
+            handleProcessingException(e);
             throw new SparkProcessException(e);
           }
           outputIterator = getOutputIterator();
@@ -249,15 +250,31 @@ public abstract class SparkProcessContext<InputT, OutputT, ValueT>
               calledFinish = true;
               doFn.finishBundle(SparkProcessContext.this);
             } catch (Exception e) {
+              handleProcessingException(e);
               throw new SparkProcessException(e);
             }
             outputIterator = getOutputIterator();
             continue; // try to consume outputIterator from start of loop
           }
+          try {
+            doFn.teardown();
+          } catch (Exception e) {
+            LOG.error(
+                "Suppressing teardown exception that occurred after processing entire input", e);
+          }
           return endOfData();
         }
       }
     }
+
+    private void handleProcessingException(Exception processingException) {
+      try {
+        doFn.teardown(); // best-effort: a DoFn that threw is never reused, so release it now
+      } catch (Exception teardownException) {
+        LOG.error("Exception while cleaning up DoFn", teardownException);
+        processingException.addSuppressed(teardownException);
+      }
+    }
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java
index a06467e..80b67af 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java
@@ -342,6 +342,20 @@ public abstract class DoFn<InputT, OutputT> implements Serializable, HasDisplayD
 
   /////////////////////////////////////////////////////////////////////////////
 
+  /**
+   * Annotation for the method to use to prepare an instance for processing bundles of elements. The
+   * method annotated with this must satisfy the following constraints:
+   * <ul>
+   *   <li>It must have zero arguments.
+   *   <li>It must have a {@code void} return type.
+   * </ul>
+   */
+  @Documented
+  @Retention(RetentionPolicy.RUNTIME)
+  @Target(ElementType.METHOD)
+  public @interface Setup {
+  }
+
   /**
    * Annotation for the method to use to prepare an instance for processing a batch of elements.
    * The method annotated with this must satisfy the following constraints:
@@ -371,7 +385,7 @@ public abstract class DoFn<InputT, OutputT> implements Serializable, HasDisplayD
   public @interface ProcessElement {}
 
   /**
-   * Annotation for the method to use to prepare an instance for processing a batch of elements.
+   * Annotation for the method to use to finish processing a batch of elements.
    * The method annotated with this must satisfy the following constraints:
    * <ul>
    *   <li>It must have at least one argument.
@@ -383,6 +397,21 @@ public abstract class DoFn<InputT, OutputT> implements Serializable, HasDisplayD
   @Target(ElementType.METHOD)
   public @interface FinishBundle {}
 
+  /**
+   * Annotation for the method to use to clean up this instance after processing bundles of
+   * elements. No other method will be called after a call to the annotated method is made.
+   * The method annotated with this must satisfy the following constraints:
+   * <ul>
+   *   <li>It must have zero arguments.
+   *   <li>It must have a {@code void} return type.
+   * </ul>
+   */
+  @Documented
+  @Retention(RetentionPolicy.RUNTIME)
+  @Target(ElementType.METHOD)
+  public @interface Teardown {
+  }
+
   /**
    * Returns an {@link Aggregator} with aggregation logic specified by the
    * {@link CombineFn} argument. The name provided must be unique across

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnReflector.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnReflector.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnReflector.java
index 3dfda55..bf04041 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnReflector.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnReflector.java
@@ -17,11 +17,15 @@
  */
 package org.apache.beam.sdk.transforms;
 
+import static com.google.common.base.Preconditions.checkState;
+
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.transforms.DoFn.ExtraContextFactory;
 import org.apache.beam.sdk.transforms.DoFn.FinishBundle;
 import org.apache.beam.sdk.transforms.DoFn.ProcessElement;
+import org.apache.beam.sdk.transforms.DoFn.Setup;
 import org.apache.beam.sdk.transforms.DoFn.StartBundle;
+import org.apache.beam.sdk.transforms.DoFn.Teardown;
 import org.apache.beam.sdk.transforms.display.DisplayData;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.PaneInfo;
@@ -302,6 +306,15 @@ public abstract class DoFnReflector {
         new TypeParameter<OutputT>() {});
   }
 
+  @VisibleForTesting
+  static void verifyLifecycleMethodArguments(Method m) {
+    if (m != null) {
+      checkState(void.class.equals(m.getReturnType()), "%s must have void return type", format(m));
+      checkState(
+          m.getGenericParameterTypes().length == 0, "%s must take zero arguments", format(m));
+    }
+  }
+
   /**
    * Verify the method arguments for a given {@link DoFn} method.
    *
@@ -392,6 +405,8 @@ public abstract class DoFnReflector {
 
   /** Interface for invoking the {@code OldDoFn} processing methods. */
   public interface DoFnInvoker<InputT, OutputT>  {
+    /** Invoke the {@link DoFn.Setup} method on the bound {@code DoFn}. */
+    void invokeSetup();
     /** Invoke {@link OldDoFn#startBundle} on the bound {@code OldDoFn}. */
     void invokeStartBundle(
         DoFn<InputT, OutputT>.Context c,
@@ -401,6 +416,9 @@ public abstract class DoFnReflector {
         DoFn<InputT, OutputT>.Context c,
         ExtraContextFactory<InputT, OutputT> extra);
 
+    /** Invoke the {@link DoFn.Teardown} method on the bound {@code DoFn}. */
+    void invokeTeardown();
+
     /** Invoke {@link OldDoFn#processElement} on the bound {@code OldDoFn}. */
     public void invokeProcessElement(
         DoFn<InputT, OutputT>.ProcessContext c,
@@ -412,9 +430,11 @@ public abstract class DoFnReflector {
    */
   private static class GenericDoFnReflector extends DoFnReflector {
 
+    private final Method setup;
     private final Method startBundle;
     private final Method processElement;
     private final Method finishBundle;
+    private final Method teardown;
     private final List<AdditionalParameter> processElementArgs;
     private final List<AdditionalParameter> startBundleArgs;
     private final List<AdditionalParameter> finishBundleArgs;
@@ -424,13 +444,17 @@ public abstract class DoFnReflector {
         @SuppressWarnings("rawtypes") Class<? extends DoFn> fn) {
       // Locate the annotated methods
       this.processElement = findAnnotatedMethod(ProcessElement.class, fn, true);
+      this.setup = findAnnotatedMethod(Setup.class, fn, false);
       this.startBundle = findAnnotatedMethod(StartBundle.class, fn, false);
       this.finishBundle = findAnnotatedMethod(FinishBundle.class, fn, false);
+      this.teardown = findAnnotatedMethod(Teardown.class, fn, false);
 
       // Verify that their method arguments satisfy our conditions.
       this.processElementArgs = verifyProcessMethodArguments(processElement);
       this.startBundleArgs = verifyBundleMethodArguments(startBundle);
       this.finishBundleArgs = verifyBundleMethodArguments(finishBundle);
+      verifyLifecycleMethodArguments(setup);
+      verifyLifecycleMethodArguments(teardown);
 
       this.constructor = createInvokerConstructor(fn);
     }
@@ -552,8 +576,17 @@ public abstract class DoFnReflector {
           .intercept(InvokerDelegation.create(
               startBundle, BeforeDelegation.INVOKE_PREPARE_FOR_PROCESSING, startBundleArgs))
           .method(ElementMatchers.named("invokeFinishBundle"))
-          .intercept(InvokerDelegation.create(
-              finishBundle, BeforeDelegation.NOOP, finishBundleArgs));
+          .intercept(InvokerDelegation.create(finishBundle,
+              BeforeDelegation.NOOP,
+              finishBundleArgs))
+          .method(ElementMatchers.named("invokeSetup"))
+          .intercept(InvokerDelegation.create(setup,
+              BeforeDelegation.NOOP,
+              Collections.<AdditionalParameter>emptyList()))
+          .method(ElementMatchers.named("invokeTeardown"))
+          .intercept(InvokerDelegation.create(teardown,
+              BeforeDelegation.NOOP,
+              Collections.<AdditionalParameter>emptyList()));
 
       @SuppressWarnings("unchecked")
       Class<? extends DoFnInvoker<?, ?>> dynamicClass = (Class<? extends DoFnInvoker<?, ?>>) builder
@@ -736,6 +769,11 @@ public abstract class DoFnReflector {
     }
 
     @Override
+    public void setup() throws Exception {
+      invoker.invokeSetup(); // delegates to the DoFn's @Setup method via the DoFnInvoker
+    }
+
+    @Override
     public void startBundle(OldDoFn<InputT, OutputT>.Context c) throws Exception {
       ContextAdapter<InputT, OutputT> adapter = new ContextAdapter<>(fn, c);
       invoker.invokeStartBundle(adapter, adapter);
@@ -748,6 +786,11 @@ public abstract class DoFnReflector {
     }
 
     @Override
+    public void teardown() {
+      invoker.invokeTeardown(); // delegates to the DoFn's @Teardown method via the DoFnInvoker
+    }
+
+    @Override
     public void processElement(OldDoFn<InputT, OutputT>.ProcessContext c) throws Exception {
       ProcessContextAdapter<InputT, OutputT> adapter = new ProcessContextAdapter<>(fn, c);
       invoker.invokeProcessElement(adapter, adapter);
@@ -940,15 +983,20 @@ public abstract class DoFnReflector {
           new MethodDescription.ForLoadedMethod(target)).resolve(instrumentedMethod);
       ParameterList<?> params = targetMethod.getParameters();
 
-      // Instructions to setup the parameters for the call
-      ArrayList<StackManipulation> parameters = new ArrayList<>(args.size() + 1);
-      // 1. The first argument in the delegate method must be the context. This corresponds to
-      //    the first argument in the instrumented method, so copy that.
-      parameters.add(MethodVariableAccess.of(
-          params.get(0).getType().getSuperClass()).loadOffset(1));
-      // 2. For each of the extra arguments push the appropriate value.
-      for (AdditionalParameter arg : args) {
-        parameters.add(pushArgument(arg, instrumentedMethod));
+      List<StackManipulation> parameters;
+      if (!params.isEmpty()) {
+        // Instructions to setup the parameters for the call
+        parameters = new ArrayList<>(args.size() + 1);
+        // 1. The first argument in the delegate method must be the context. This corresponds to
+        //    the first argument in the instrumented method, so copy that.
+        parameters.add(MethodVariableAccess.of(params.get(0).getType().getSuperClass())
+            .loadOffset(1));
+        // 2. For each of the extra arguments push the appropriate value.
+        for (AdditionalParameter arg : args) {
+          parameters.add(pushArgument(arg, instrumentedMethod));
+        }
+      } else {
+        parameters = Collections.emptyList();
       }
 
       return new StackManipulation.Compound(

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/OldDoFn.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/OldDoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/OldDoFn.java
index 443599a..84cd997 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/OldDoFn.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/OldDoFn.java
@@ -339,6 +339,17 @@ public abstract class OldDoFn<InputT, OutputT> implements Serializable, HasDispl
   private boolean aggregatorsAreFinal;
 
   /**
+   * Prepares this {@code OldDoFn} instance for processing bundles.
+   *
+   * <p>{@link #setup()} will be called at most once per {@code OldDoFn} instance, and before any
+   * other {@code OldDoFn} method is called.
+   *
+   * <p>By default, does nothing.
+   */
+  public void setup() throws Exception {
+  }
+
+  /**
    * Prepares this {@code OldDoFn} instance for processing a batch of elements.
    *
    * <p>By default, does nothing.
@@ -373,6 +384,20 @@ public abstract class OldDoFn<InputT, OutputT> implements Serializable, HasDispl
   }
 
   /**
+   * Cleans up this {@code OldDoFn}.
+   *
+   * <p>{@link #teardown()} will be called before the {@link PipelineRunner} discards this
+   * instance, including due to another {@code OldDoFn} method throwing an {@link Exception}. No
+   * other {@code OldDoFn} methods will be called after a call to {@link #teardown()}.
+   *
+   * <p>By default, does nothing.
+   */
+  public void teardown() throws Exception {
+  }
+
+  /////////////////////////////////////////////////////////////////////////////
+
+  /**
    * {@inheritDoc}
    *
    * <p>By default, does not register any display data. Implementors may override this method

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
index ca6d9b2..aa57531 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
@@ -67,11 +67,11 @@ import java.util.List;
  * For each bundle of input elements processing proceeds as follows:
  *
  * <ol>
- *   <li>If required, a fresh instance of the argument {@link OldDoFn} is created
- *     on a worker. This may be through deserialization or other means. A
- *     {@link PipelineRunner} may reuse {@link OldDoFn} instances for multiple bundles.
- *     A {@link OldDoFn} that has terminated abnormally (by throwing an {@link Exception}
- *     will never be reused.</li>
+ *   <li>If required, a fresh instance of the argument {@link DoFn} is created
+ *     on a worker. This may be through deserialization or other means. Its
+ *     {@link DoFn.Setup @Setup} method, if any, is then called. A {@link PipelineRunner}
+ *     may reuse {@link DoFn} instances for multiple bundles. A {@link DoFn} that has terminated
+ *     abnormally (by throwing an {@link Exception}) will never be reused.</li>
  *   <li>The {@link OldDoFn OldDoFn's} {@link OldDoFn#startBundle} method is called to
  *     initialize it. If this method is not overridden, the call may be optimized
  *     away.</li>
@@ -83,6 +83,11 @@ import java.util.List;
  *     {@link OldDoFn#finishBundle}
  *     until a new call to {@link OldDoFn#startBundle} has occurred.
  *     If this method is not overridden, this call may be optimized away.</li>
+ *   <li>If any of the {@link DoFn.Setup @Setup}, {@link DoFn.StartBundle @StartBundle},
+ *     {@link DoFn.ProcessElement @ProcessElement} or {@link DoFn.FinishBundle @FinishBundle}
+ *     methods throws an exception, the {@link DoFn.Teardown @Teardown} method will be called.</li>
+ *   <li>If a runner will no longer use a {@link DoFn}, its {@link DoFn.Teardown @Teardown}
+ *     method will be called on the discarded instance.</li>
  * </ol>
  *
  * Each of the calls to any of the {@link OldDoFn OldDoFn's} processing

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnReflectorTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnReflectorTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnReflectorTest.java
index c47e0cf..e05e5e2 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnReflectorTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnReflectorTest.java
@@ -25,6 +25,8 @@ import org.apache.beam.sdk.transforms.DoFn.Context;
 import org.apache.beam.sdk.transforms.DoFn.ExtraContextFactory;
 import org.apache.beam.sdk.transforms.DoFn.ProcessContext;
 import org.apache.beam.sdk.transforms.DoFn.ProcessElement;
+import org.apache.beam.sdk.transforms.DoFn.Setup;
+import org.apache.beam.sdk.transforms.DoFn.Teardown;
 import org.apache.beam.sdk.transforms.dofnreflector.DoFnReflectorTestHelper;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.util.UserCodeException;
@@ -53,6 +55,8 @@ public class DoFnReflectorTest {
     public boolean wasProcessElementInvoked = false;
     public boolean wasStartBundleInvoked = false;
     public boolean wasFinishBundleInvoked = false;
+    public boolean wasSetupInvoked = false;
+    public boolean wasTeardownInvoked = false;
     private final String name;
 
     public Invocations(String name) {
@@ -144,6 +148,33 @@ public class DoFnReflectorTest {
     }
   }
 
+  private void checkInvokeSetupWorks(DoFnReflector r, Invocations... invocations)
+      throws Exception {
+    assertTrue("Need at least one invocation to check", invocations.length >= 1);
+    for (Invocations inv : invocations) {
+      assertFalse("Should not yet have called setup on " + inv.name, inv.wasSetupInvoked);
+    }
+    DoFnReflector.DoFnInvoker<?, ?> invoker = r.bindInvoker(fn);
+    invoker.invokeSetup();
+    for (Invocations inv : invocations) {
+      assertTrue("Should have called setup on " + inv.name, inv.wasSetupInvoked);
+    }
+  }
+
+  private void checkInvokeTeardownWorks(DoFnReflector r, Invocations... invocations)
+      throws Exception {
+    assertTrue("Need at least one invocation to check", invocations.length >= 1);
+    for (Invocations inv : invocations) {
+      assertFalse("Should not yet have called teardown on " + inv.name,
+          inv.wasTeardownInvoked);
+    }
+    DoFnReflector.DoFnInvoker<?, ?> invoker = r.bindInvoker(fn);
+    invoker.invokeTeardown();
+    for (Invocations inv : invocations) {
+      assertTrue("Should have called teardown on " + inv.name, inv.wasTeardownInvoked);
+    }
+  }
+
   @Test
   public void testDoFnWithNoExtraContext() throws Exception {
     final Invocations invocations = new Invocations("AnonymousClass");
@@ -325,6 +356,40 @@ public class DoFnReflectorTest {
   }
 
   @Test
+  public void testDoFnWithSetupTeardown() throws Exception {
+    final Invocations invocations = new Invocations("AnonymousClass");
+    DoFnReflector reflector = underTest(new DoFn<String, String>() {
+      @Setup
+      public void setUpInstance() {
+        invocations.wasSetupInvoked = true;
+      }
+
+      @StartBundle
+      public void startBundle(Context c) {
+        invocations.wasStartBundleInvoked = true;
+        assertSame(c, mockContext);
+      }
+
+      @ProcessElement
+      public void processElement(@SuppressWarnings("unused") ProcessContext c) {}
+
+      @FinishBundle
+      public void finishBundle(Context c) {
+        invocations.wasFinishBundleInvoked = true;
+        assertSame(c, mockContext);
+      }
+
+      @Teardown
+      public void tearDownInstance() {
+        invocations.wasTeardownInvoked = true;
+      }
+    });
+
+    checkInvokeSetupWorks(reflector, invocations);
+    checkInvokeTeardownWorks(reflector, invocations);
+  }
+
+  @Test
   public void testNoProcessElement() throws Exception {
     thrown.expect(IllegalStateException.class);
     thrown.expectMessage("No method annotated with @ProcessElement found");

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/12abb1b0/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
index 7fe053c..8460124 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
@@ -24,17 +24,18 @@ import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.include
 import static org.apache.beam.sdk.util.SerializableUtils.serializeToByteArray;
 import static org.apache.beam.sdk.util.StringUtils.byteArrayToJsonString;
 import static org.apache.beam.sdk.util.StringUtils.jsonStringToByteArray;
-
 import static com.google.common.base.Preconditions.checkNotNull;
-
 import static org.hamcrest.Matchers.allOf;
 import static org.hamcrest.Matchers.anyOf;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.greaterThan;
+import static org.hamcrest.Matchers.is;
 import static org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder;
 import static org.hamcrest.collection.IsIterableContainingInOrder.contains;
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertThat;
+import static org.junit.Assert.fail;
 
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.AtomicCoder;
@@ -53,6 +54,7 @@ import org.apache.beam.sdk.transforms.windowing.FixedWindows;
 import org.apache.beam.sdk.transforms.windowing.Window;
 import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionList;
 import org.apache.beam.sdk.values.PCollectionTuple;
 import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.TimestampedValue;
@@ -60,6 +62,7 @@ import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.sdk.values.TupleTagList;
 
 import com.fasterxml.jackson.annotation.JsonCreator;
+
 import org.joda.time.Duration;
 import org.joda.time.Instant;
 import org.junit.Rule;
@@ -77,6 +80,7 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
+import java.util.concurrent.atomic.AtomicBoolean;
 
 /**
  * Tests for ParDo.
@@ -169,8 +173,10 @@ public class ParDoTest implements Serializable {
   }
 
   static class TestDoFn extends DoFn<Integer, String> {
-    enum State { UNSTARTED, STARTED, PROCESSING, FINISHED }
-    State state = State.UNSTARTED;
+    enum State { NOT_SET_UP, UNSTARTED, STARTED, PROCESSING, FINISHED }
+
+    /** Lifecycle position of this instance, asserted and advanced by the lifecycle methods. */
+    State state = State.NOT_SET_UP;
 
     final List<PCollectionView<Integer>> sideInputViews = new ArrayList<>();
     final List<TupleTag<String>> sideOutputTupleTags = new ArrayList<>();
@@ -184,6 +190,12 @@ public class ParDoTest implements Serializable {
       this.sideOutputTupleTags.addAll(sideOutputTupleTags);
     }
 
+    @Setup
+    public void prepare() {
+      assertEquals(State.NOT_SET_UP, state); // enforces: at most once, and before @StartBundle
+      state = State.UNSTARTED;
+    }
+
     @StartBundle
     public void startBundle(Context c) {
       assertEquals(State.UNSTARTED, state);
@@ -1463,4 +1475,404 @@ public class ParDoTest implements Serializable {
     assertThat(displayData, includesDisplayDataFrom(fn));
     assertThat(displayData, hasDisplayItem("fn", fn.getClass()));
   }
+
+  @Test
+  @Category(RunnableOnService.class)
+  public void testFnCallSequence() {
+    TestPipeline pipeline = TestPipeline.create();
+    PCollection<Integer> impolite = pipeline.apply("Impolite", Create.of(1, 2, 4));
+    PCollection<Integer> polite = pipeline.apply("Polite", Create.of(3, 5, 6, 7));
+    PCollectionList.of(impolite).and(polite)
+        .apply(Flatten.<Integer>pCollections())
+        .apply(ParDo.of(new CallSequenceEnforcingOldFn<Integer>()));
+    pipeline.run();
+  }
+
+  @Test
+  @Category(RunnableOnService.class)
+  public void testFnCallSequenceMulti() {
+    TestPipeline pipeline = TestPipeline.create();
+    PCollection<Integer> impolite = pipeline.apply("Impolite", Create.of(1, 2, 4));
+    PCollection<Integer> polite = pipeline.apply("Polite", Create.of(3, 5, 6, 7));
+    PCollectionList.of(impolite).and(polite)
+        .apply(Flatten.<Integer>pCollections())
+        .apply(ParDo.of(new CallSequenceEnforcingOldFn<Integer>())
+            .withOutputTags(new TupleTag<Integer>() {}, TupleTagList.empty()));
+    pipeline.run();
+  }
+
+  private static class CallSequenceEnforcingOldFn<T> extends OldDoFn<T, T> {
+    private boolean setupCalled = false;
+    private int startBundleCalls = 0;
+    private int finishBundleCalls = 0;
+    private boolean teardownCalled = false;
+
+    @Override
+    public void setup() {
+      assertThat("setup should not be called twice", setupCalled, is(false));
+      assertThat("setup should be called before startBundle", startBundleCalls, equalTo(0));
+      assertThat("setup should be called before finishBundle", finishBundleCalls, equalTo(0));
+      assertThat("setup should be called before teardown", teardownCalled, is(false));
+      setupCalled = true;
+    }
+
+    @Override
+    public void startBundle(Context c) {
+      assertThat("setup should have been called", setupCalled, is(true));
+      assertThat(
+          "Equal number of startBundle and finishBundle calls in startBundle",
+          startBundleCalls,
+          equalTo(finishBundleCalls));
+      assertThat("teardown should not have been called", teardownCalled, is(false));
+      startBundleCalls++;
+    }
+
+    @Override
+    public void processElement(ProcessContext c) throws Exception {
+      assertThat("startBundle should have been called", startBundleCalls, greaterThan(0));
+      assertThat(
+          "there should be one startBundle call with no call to finishBundle",
+          startBundleCalls,
+          equalTo(finishBundleCalls + 1));
+      assertThat("teardown should not have been called", teardownCalled, is(false));
+    }
+
+    @Override
+    public void finishBundle(Context c) {
+      assertThat("startBundle should have been called", startBundleCalls, greaterThan(0));
+      assertThat(
+          "there should be one bundle that has been started but not finished",
+          startBundleCalls,
+          equalTo(finishBundleCalls + 1));
+      assertThat("teardown should not have been called", teardownCalled, is(false));
+      finishBundleCalls++;
+    }
+
+    @Override
+    public void teardown() {
+      assertThat("setup should have been called", setupCalled, is(true));
+      assertThat("bundles should all be finished", startBundleCalls, equalTo(finishBundleCalls));
+      assertThat("teardown should not be called twice", teardownCalled, is(false));
+      teardownCalled = true;
+    }
+  }
+
+  @Test
+  @Category(RunnableOnService.class)
+  public void testFnWithContextCallSequence() {
+    TestPipeline pipeline = TestPipeline.create();
+    PCollection<Integer> impolite = pipeline.apply("Impolite", Create.of(1, 2, 4));
+    PCollection<Integer> polite = pipeline.apply("Polite", Create.of(3, 5, 6, 7));
+    PCollectionList.of(impolite).and(polite)
+        .apply(Flatten.<Integer>pCollections())
+        .apply(ParDo.of(new CallSequenceEnforcingFn<Integer>()));
+    pipeline.run();
+  }
+
+  @Test
+  @Category(RunnableOnService.class)
+  public void testFnWithContextCallSequenceMulti() {
+    TestPipeline pipeline = TestPipeline.create();
+    PCollection<Integer> impolite = pipeline.apply("Impolite", Create.of(1, 2, 4));
+    PCollection<Integer> polite = pipeline.apply("Polite", Create.of(3, 5, 6, 7));
+    PCollectionList.of(impolite).and(polite)
+        .apply(Flatten.<Integer>pCollections())
+        .apply(ParDo.of(new CallSequenceEnforcingFn<Integer>())
+            .withOutputTags(new TupleTag<Integer>() {}, TupleTagList.empty()));
+
+    pipeline.run();
+  }
+
+  private static class CallSequenceEnforcingFn<T> extends DoFn<T, T> {
+    private boolean setupCalled = false;
+    private int startBundleCalls = 0;
+    private int finishBundleCalls = 0;
+    private boolean teardownCalled = false;
+
+    @Setup
+    public void before() {
+      assertThat("setup should not be called twice", setupCalled, is(false));
+      assertThat("setup should be called before startBundle", startBundleCalls, equalTo(0));
+      assertThat("setup should be called before finishBundle", finishBundleCalls, equalTo(0));
+      assertThat("setup should be called before teardown", teardownCalled, is(false));
+      setupCalled = true;
+    }
+
+    @StartBundle
+    public void begin(Context c) {
+      assertThat("setup should have been called", setupCalled, is(true));
+      assertThat("Equal number of startBundle and finishBundle calls in startBundle",
+          startBundleCalls,
+          equalTo(finishBundleCalls));
+      assertThat("teardown should not have been called", teardownCalled, is(false));
+      startBundleCalls++;
+    }
+
+    @ProcessElement
+    public void process(ProcessContext c) throws Exception {
+      assertThat("startBundle should have been called", startBundleCalls, greaterThan(0));
+      assertThat("there should be one startBundle call with no call to finishBundle",
+          startBundleCalls,
+          equalTo(finishBundleCalls + 1));
+      assertThat("teardown should not have been called", teardownCalled, is(false));
+    }
+
+    @FinishBundle
+    public void end(Context c) {
+      assertThat("startBundle should have been called", startBundleCalls, greaterThan(0));
+      assertThat("there should be one bundle that has been started but not finished",
+          startBundleCalls,
+          equalTo(finishBundleCalls + 1));
+      assertThat("teardown should not have been called", teardownCalled, is(false));
+      finishBundleCalls++;
+    }
+
+    @Teardown
+    public void after() {
+      assertThat("setup should have been called", setupCalled, is(true));
+      assertThat("bundles should all be finished", startBundleCalls, equalTo(finishBundleCalls));
+      assertThat("teardown should not be called twice", teardownCalled, is(false));
+      teardownCalled = true;
+    }
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testTeardownCalledAfterExceptionInSetup() {
+    TestPipeline p = TestPipeline.create();
+    ExceptionThrowingOldFn fn = new ExceptionThrowingOldFn(MethodForException.SETUP);
+    p.apply(Create.of(1, 2, 3)).apply(ParDo.of(fn));
+    try {
+      p.run();
+    } catch (Exception expected) {
+      // NOTE(review): teardownCalled is static and not reset here; confirm it is reset per test.
+      assertThat(
+          "Function should have been torn down after exception",
+          ExceptionThrowingOldFn.teardownCalled.get(),
+          is(true));
+      return;
+    }
+    fail("Pipeline should have failed with an exception");
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testTeardownCalledAfterExceptionInStartBundle() {
+    TestPipeline p = TestPipeline.create();
+    ExceptionThrowingOldFn fn = new ExceptionThrowingOldFn(MethodForException.START_BUNDLE);
+    p.apply(Create.of(1, 2, 3)).apply(ParDo.of(fn));
+    try {
+      p.run();
+    } catch (Exception expected) {
+      assertThat(
+          "Function should have been torn down after exception",
+          ExceptionThrowingOldFn.teardownCalled.get(),
+          is(true));
+      return;
+    }
+
+    fail("Pipeline should have failed with an exception");
+  }
+
+  /** Verifies that teardown still runs on an {@link OldDoFn} whose processElement throws. */
+  @Test
+  @Category(NeedsRunner.class)
+  public void testTeardownCalledAfterExceptionInProcessElement() {
+    // teardownCalled is static and shared by every test; reset it so a teardown observed in
+    // an earlier test cannot make this one pass vacuously.
+    ExceptionThrowingOldFn.teardownCalled.set(false);
+    TestPipeline p = TestPipeline.create();
+    ExceptionThrowingOldFn fn = new ExceptionThrowingOldFn(MethodForException.PROCESS_ELEMENT);
+    p
+        .apply(Create.of(1, 2, 3))
+        .apply(ParDo.of(fn));
+    try {
+      p.run();
+      // fail() throws an AssertionError, which is not an Exception, so it escapes the catch.
+      fail("Pipeline should have failed with an exception");
+    } catch (Exception e) {
+      assertThat(
+          "Function should have been torn down after exception",
+          ExceptionThrowingOldFn.teardownCalled.get(),
+          is(true));
+    }
+  }
+
+  /** Verifies that teardown still runs on an {@link OldDoFn} whose finishBundle throws. */
+  @Test
+  @Category(NeedsRunner.class)
+  public void testTeardownCalledAfterExceptionInFinishBundle() {
+    // teardownCalled is static and shared by every test; reset it so a teardown observed in
+    // an earlier test cannot make this one pass vacuously.
+    ExceptionThrowingOldFn.teardownCalled.set(false);
+    TestPipeline p = TestPipeline.create();
+    ExceptionThrowingOldFn fn = new ExceptionThrowingOldFn(MethodForException.FINISH_BUNDLE);
+    p
+        .apply(Create.of(1, 2, 3))
+        .apply(ParDo.of(fn));
+    try {
+      p.run();
+      // fail() throws an AssertionError, which is not an Exception, so it escapes the catch.
+      fail("Pipeline should have failed with an exception");
+    } catch (Exception e) {
+      assertThat(
+          "Function should have been torn down after exception",
+          ExceptionThrowingOldFn.teardownCalled.get(),
+          is(true));
+    }
+  }
+
+  /**
+   * Verifies that the {@code @Teardown} method of an annotation-based {@link DoFn} still runs
+   * when its {@code @Setup} method throws.
+   */
+  @Test
+  @Category(NeedsRunner.class)
+  public void testWithContextTeardownCalledAfterExceptionInSetup() {
+    // Uses the annotation-based fn; the OldDoFn path is covered by
+    // testTeardownCalledAfterExceptionInSetup. The flag is static, so reset it first.
+    ExceptionThrowingFn.teardownCalled.set(false);
+    TestPipeline p = TestPipeline.create();
+    ExceptionThrowingFn fn = new ExceptionThrowingFn(MethodForException.SETUP);
+    p.apply(Create.of(1, 2, 3)).apply(ParDo.of(fn));
+    try {
+      p.run();
+      // fail() throws an AssertionError, which is not an Exception, so it escapes the catch.
+      fail("Pipeline should have failed with an exception");
+    } catch (Exception e) {
+      assertThat("Function should have been torn down after exception",
+          ExceptionThrowingFn.teardownCalled.get(),
+          is(true));
+    }
+  }
+
+  /**
+   * Verifies that the {@code @Teardown} method of an annotation-based {@link DoFn} still runs
+   * when its {@code @StartBundle} method throws.
+   */
+  @Test
+  @Category(NeedsRunner.class)
+  public void testWithContextTeardownCalledAfterExceptionInStartBundle() {
+    // Uses the annotation-based fn; the OldDoFn path is covered by
+    // testTeardownCalledAfterExceptionInStartBundle. The flag is static, so reset it first.
+    ExceptionThrowingFn.teardownCalled.set(false);
+    TestPipeline p = TestPipeline.create();
+    ExceptionThrowingFn fn = new ExceptionThrowingFn(MethodForException.START_BUNDLE);
+    p.apply(Create.of(1, 2, 3)).apply(ParDo.of(fn));
+    try {
+      p.run();
+      // fail() throws an AssertionError, which is not an Exception, so it escapes the catch.
+      fail("Pipeline should have failed with an exception");
+    } catch (Exception e) {
+      assertThat("Function should have been torn down after exception",
+          ExceptionThrowingFn.teardownCalled.get(),
+          is(true));
+    }
+  }
+
+  /**
+   * Verifies that the {@code @Teardown} method of an annotation-based {@link DoFn} still runs
+   * when its {@code @ProcessElement} method throws.
+   */
+  @Test
+  @Category(NeedsRunner.class)
+  public void testWithContextTeardownCalledAfterExceptionInProcessElement() {
+    // Uses the annotation-based fn; the OldDoFn path is covered by
+    // testTeardownCalledAfterExceptionInProcessElement. The flag is static, so reset it first.
+    ExceptionThrowingFn.teardownCalled.set(false);
+    TestPipeline p = TestPipeline.create();
+    ExceptionThrowingFn fn = new ExceptionThrowingFn(MethodForException.PROCESS_ELEMENT);
+    p.apply(Create.of(1, 2, 3)).apply(ParDo.of(fn));
+    try {
+      p.run();
+      // fail() throws an AssertionError, which is not an Exception, so it escapes the catch.
+      fail("Pipeline should have failed with an exception");
+    } catch (Exception e) {
+      assertThat("Function should have been torn down after exception",
+          ExceptionThrowingFn.teardownCalled.get(),
+          is(true));
+    }
+  }
+
+  /**
+   * Verifies that the {@code @Teardown} method of an annotation-based {@link DoFn} still runs
+   * when its {@code @FinishBundle} method throws.
+   */
+  @Test
+  @Category(NeedsRunner.class)
+  public void testWithContextTeardownCalledAfterExceptionInFinishBundle() {
+    // Uses the annotation-based fn; the OldDoFn path is covered by
+    // testTeardownCalledAfterExceptionInFinishBundle. The flag is static, so reset it first.
+    ExceptionThrowingFn.teardownCalled.set(false);
+    TestPipeline p = TestPipeline.create();
+    ExceptionThrowingFn fn = new ExceptionThrowingFn(MethodForException.FINISH_BUNDLE);
+    p.apply(Create.of(1, 2, 3)).apply(ParDo.of(fn));
+    try {
+      p.run();
+      // fail() throws an AssertionError, which is not an Exception, so it escapes the catch.
+      fail("Pipeline should have failed with an exception");
+    } catch (Exception e) {
+      assertThat("Function should have been torn down after exception",
+          ExceptionThrowingFn.teardownCalled.get(),
+          is(true));
+    }
+  }
+
+  /**
+   * An {@link OldDoFn} that throws from the lifecycle method selected by {@code toThrow}, at
+   * most once per instance, and records in {@link #teardownCalled} that teardown ran.
+   */
+  private static class ExceptionThrowingOldFn extends OldDoFn<Object, Object> {
+    /**
+     * Set to {@code true} by {@link #teardown()}. Static, so it is shared by every instance and
+     * every test; each test must reset it before running a pipeline.
+     */
+    static final AtomicBoolean teardownCalled = new AtomicBoolean(false);
+
+    /** The lifecycle method that should throw. */
+    private final MethodForException toThrow;
+    /** Whether this instance has already thrown; keeps the throw one-shot. */
+    private boolean thrown;
+
+    private ExceptionThrowingOldFn(MethodForException toThrow) {
+      this.toThrow = toThrow;
+    }
+
+    @Override
+    public void setup() throws Exception {
+      throwIfNecessary(MethodForException.SETUP);
+    }
+
+    @Override
+    public void startBundle(Context c) throws Exception {
+      throwIfNecessary(MethodForException.START_BUNDLE);
+    }
+
+    @Override
+    public void processElement(ProcessContext c) throws Exception {
+      throwIfNecessary(MethodForException.PROCESS_ELEMENT);
+    }
+
+    @Override
+    public void finishBundle(Context c) throws Exception {
+      throwIfNecessary(MethodForException.FINISH_BUNDLE);
+    }
+
+    /**
+     * Throws if {@code method} is the configured method and this instance has not thrown yet.
+     *
+     * @param method the lifecycle method currently executing
+     * @throws Exception the first time the configured method is reached
+     */
+    private void throwIfNecessary(MethodForException method) throws Exception {
+      if (toThrow == method && !thrown) {
+        thrown = true;
+        throw new Exception("Hasn't yet thrown");
+      }
+    }
+
+    /**
+     * Fails the test if no lifecycle method threw before teardown; otherwise records that
+     * teardown ran.
+     */
+    @Override
+    public void teardown() {
+      if (!thrown) {
+        fail("Expected to have a processing method throw an exception");
+      }
+      teardownCalled.set(true);
+    }
+  }
+
+
+  /**
+   * An annotation-based {@link DoFn} that throws from the lifecycle method selected by
+   * {@code toThrow}, at most once per instance, and records in {@link #teardownCalled} that its
+   * {@code @Teardown} method ran.
+   */
+  private static class ExceptionThrowingFn extends DoFn<Object, Object> {
+    /**
+     * Set to {@code true} by the {@code @Teardown} method. Static, so it is shared by every
+     * instance and every test; each test must reset it before running a pipeline.
+     */
+    static final AtomicBoolean teardownCalled = new AtomicBoolean(false);
+
+    /** The lifecycle method that should throw. */
+    private final MethodForException toThrow;
+    /** Whether this instance has already thrown; keeps the throw one-shot. */
+    private boolean thrown;
+
+    private ExceptionThrowingFn(MethodForException toThrow) {
+      this.toThrow = toThrow;
+    }
+
+    @Setup
+    public void before() throws Exception {
+      throwIfNecessary(MethodForException.SETUP);
+    }
+
+    @StartBundle
+    public void preBundle(Context c) throws Exception {
+      throwIfNecessary(MethodForException.START_BUNDLE);
+    }
+
+    @ProcessElement
+    public void perElement(ProcessContext c) throws Exception {
+      throwIfNecessary(MethodForException.PROCESS_ELEMENT);
+    }
+
+    @FinishBundle
+    public void postBundle(Context c) throws Exception {
+      throwIfNecessary(MethodForException.FINISH_BUNDLE);
+    }
+
+    /**
+     * Throws if {@code method} is the configured method and this instance has not thrown yet.
+     *
+     * @param method the lifecycle method currently executing
+     * @throws Exception the first time the configured method is reached
+     */
+    private void throwIfNecessary(MethodForException method) throws Exception {
+      if (toThrow == method && !thrown) {
+        thrown = true;
+        throw new Exception("Hasn't yet thrown");
+      }
+    }
+
+    /**
+     * Fails the test if no lifecycle method threw before teardown; otherwise records that
+     * teardown ran.
+     */
+    @Teardown
+    public void after() {
+      if (!thrown) {
+        fail("Expected to have a processing method throw an exception");
+      }
+      teardownCalled.set(true);
+    }
+  }
+
+  /** The fn lifecycle method from which an exception-throwing test fn should throw. */
+  private enum MethodForException {
+    SETUP, START_BUNDLE, PROCESS_ELEMENT, FINISH_BUNDLE
+  }
 }


Mime
View raw message