beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From k...@apache.org
Subject [1/2] incubator-beam git commit: [BEAM-37] DoFnReflector: Add invoker interface and generate code
Date Fri, 15 Jul 2016 00:03:49 GMT
Repository: incubator-beam
Updated Branches:
  refs/heads/master 779fb1a6c -> 79c26d9c1


[BEAM-37] DoFnReflector: Add invoker interface and generate code

The method to call for a DoFnWithContext must be located via reflection, since
the shape of its parameters may vary. Doing so in each processElement call
puts this reflection in the hot path.

This PR introduces a DoFnInvoker interface which is bound to a specific
DoFnWithContext and delegates the three important methods (startBundle,
processElement, finishBundle).

It uses byte-buddy to generate a simple trampoline implementation of
the DoFnInvoker interface for each type of DoFnWithContext.

This leads to 2-3x better performance in micro-benchmarks of method
dispatching.


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/2b47919c
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/2b47919c
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/2b47919c

Branch: refs/heads/master
Commit: 2b47919c18d6485c3bb3df8452bd67c940f00d65
Parents: 2a30f52
Author: Ben Chambers <bchambers@google.com>
Authored: Wed Jun 22 06:47:23 2016 -0700
Committer: Kenneth Knowles <klk@google.com>
Committed: Thu Jul 14 16:41:33 2016 -0700

----------------------------------------------------------------------
 pom.xml                                         |   6 +
 sdks/java/core/pom.xml                          |   5 +
 .../beam/sdk/transforms/DoFnReflector.java      | 717 ++++++++++++++-----
 .../beam/sdk/transforms/DoFnReflectorTest.java  | 197 ++++-
 .../transforms/DoFnReflectorBenchmark.java      |   7 +-
 5 files changed, 724 insertions(+), 208 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2b47919c/pom.xml
----------------------------------------------------------------------
diff --git a/pom.xml b/pom.xml
index 7089c2c..1513735 100644
--- a/pom.xml
+++ b/pom.xml
@@ -590,6 +590,12 @@
         <version>${slf4j.version}</version>
       </dependency>
 
+      <dependency>
+        <groupId>net.bytebuddy</groupId>
+        <artifactId>byte-buddy</artifactId>
+        <version>1.4.3</version>
+      </dependency>
+
       <!-- Testing -->
 
       <dependency>

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2b47919c/sdks/java/core/pom.xml
----------------------------------------------------------------------
diff --git a/sdks/java/core/pom.xml b/sdks/java/core/pom.xml
index bda77cb..8b6fff7 100644
--- a/sdks/java/core/pom.xml
+++ b/sdks/java/core/pom.xml
@@ -434,6 +434,11 @@
     </dependency>
 
     <dependency>
+      <groupId>net.bytebuddy</groupId>
+      <artifactId>byte-buddy</artifactId>
+    </dependency>
+
+    <dependency>
       <groupId>org.apache.avro</groupId>
       <artifactId>avro</artifactId>
     </dependency>

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2b47919c/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnReflector.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnReflector.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnReflector.java
index e711d04..116b64d 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnReflector.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnReflector.java
@@ -34,15 +34,48 @@ import org.apache.beam.sdk.values.TypeDescriptor;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Function;
+import com.google.common.base.Preconditions;
 import com.google.common.collect.FluentIterable;
+import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.reflect.TypeParameter;
 import com.google.common.reflect.TypeToken;
 
+import net.bytebuddy.ByteBuddy;
+import net.bytebuddy.description.field.FieldDescription;
+import net.bytebuddy.description.method.MethodDescription;
+import net.bytebuddy.description.method.ParameterList;
+import net.bytebuddy.description.modifier.FieldManifestation;
+import net.bytebuddy.description.modifier.Visibility;
+import net.bytebuddy.description.type.TypeDescription;
+import net.bytebuddy.dynamic.DynamicType;
+import net.bytebuddy.dynamic.loading.ClassLoadingStrategy;
+import net.bytebuddy.dynamic.scaffold.InstrumentedType;
+import net.bytebuddy.dynamic.scaffold.subclass.ConstructorStrategy;
+import net.bytebuddy.implementation.Implementation;
+import net.bytebuddy.implementation.MethodCall.MethodLocator;
+import net.bytebuddy.implementation.StubMethod;
+import net.bytebuddy.implementation.bind.MethodDelegationBinder.MethodInvoker;
+import net.bytebuddy.implementation.bind.annotation.TargetMethodAnnotationDrivenBinder.TerminationHandler;
+import net.bytebuddy.implementation.bytecode.ByteCodeAppender;
+import net.bytebuddy.implementation.bytecode.Duplication;
+import net.bytebuddy.implementation.bytecode.StackManipulation;
+import net.bytebuddy.implementation.bytecode.Throw;
+import net.bytebuddy.implementation.bytecode.assign.Assigner;
+import net.bytebuddy.implementation.bytecode.member.FieldAccess;
+import net.bytebuddy.implementation.bytecode.member.MethodInvocation;
+import net.bytebuddy.implementation.bytecode.member.MethodReturn;
+import net.bytebuddy.implementation.bytecode.member.MethodVariableAccess;
+import net.bytebuddy.jar.asm.Label;
+import net.bytebuddy.jar.asm.MethodVisitor;
+import net.bytebuddy.jar.asm.Opcodes;
+import net.bytebuddy.matcher.ElementMatchers;
+
 import org.joda.time.Instant;
 
 import java.io.IOException;
 import java.lang.annotation.Annotation;
+import java.lang.reflect.Constructor;
 import java.lang.reflect.InvocationTargetException;
 import java.lang.reflect.Method;
 import java.lang.reflect.Modifier;
@@ -54,105 +87,116 @@ import java.util.Collection;
 import java.util.Collections;
 import java.util.LinkedHashMap;
 import java.util.LinkedHashSet;
+import java.util.List;
 import java.util.Map;
 
 import javax.annotation.Nullable;
 
+
 /**
  * Utility implementing the necessary reflection for working with {@link DoFnWithContext}s.
  */
 public abstract class DoFnReflector {
 
-  private interface ExtraContextInfo {
-    /**
-     * Create an instance of the given instance using the instance factory.
-     */
-    <InputT, OutputT> Object createInstance(
-        DoFnWithContext.ExtraContextFactory<InputT, OutputT> factory);
+  private static final String FN_DELEGATE_FIELD_NAME = "delegate";
 
-    /**
-     * Create the type token for the given type, filling in the generics.
-     */
-    <InputT, OutputT> TypeToken<?> tokenFor(TypeToken<InputT> in, TypeToken<OutputT> out);
+  private enum Availability {
+    /** Indicates parameters only available in {@code @ProcessElement} methods. */
+    PROCESS_ELEMENT_ONLY,
+    /** Indicates parameters available in all methods. */
+    EVERYWHERE;
   }
 
-  private static final Map<Class<?>, ExtraContextInfo> EXTRA_CONTEXTS = Collections.emptyMap();
-  private static final Map<Class<?>, ExtraContextInfo> EXTRA_PROCESS_CONTEXTS =
-      ImmutableMap.<Class<?>, ExtraContextInfo>builder()
-      .putAll(EXTRA_CONTEXTS)
-      .put(BoundedWindow.class, new ExtraContextInfo() {
-        @Override
-        public <InputT, OutputT> Object
-            createInstance(ExtraContextFactory<InputT, OutputT> factory) {
-          return factory.window();
-        }
+  /**
+   * Enumeration of the parameters available from the {@link ExtraContextFactory} to use as
+   * additional parameters for {@link DoFnWithContext} methods.
+   * <p>
+   * We don't rely on looking for properly annotated methods within {@link ExtraContextFactory}
+   * because erasure would make it impossible to completely fill in the type token for context
+   * parameters that depend on the input/output type.
+   */
+  private enum AdditionalParameter {
+
+    /** Any {@link BoundedWindow} parameter is populated by the window of the current element. */
+    WINDOW_OF_ELEMENT(Availability.PROCESS_ELEMENT_ONLY, BoundedWindow.class, "window") {
+      @Override
+      public <InputT, OutputT> TypeToken<?>
+          tokenFor(TypeToken<InputT> in, TypeToken<OutputT> out) {
+        return TypeToken.of(BoundedWindow.class);
+      }
+    },
+
+    WINDOWING_INTERNALS(Availability.PROCESS_ELEMENT_ONLY,
+        WindowingInternals.class, "windowingInternals") {
+      @Override
+      public <InputT, OutputT> TypeToken<?> tokenFor(
+          TypeToken<InputT> in, TypeToken<OutputT> out) {
+        return new TypeToken<WindowingInternals<InputT, OutputT>>() {}
+            .where(new TypeParameter<InputT>() {}, in)
+            .where(new TypeParameter<OutputT>() {}, out);
+      }
+    };
 
-        @Override
-        public <InputT, OutputT> TypeToken<?>
-            tokenFor(TypeToken<InputT> in, TypeToken<OutputT> out) {
-          return TypeToken.of(BoundedWindow.class);
-        }
-      })
-      .put(WindowingInternals.class, new ExtraContextInfo() {
-        @Override
-        public <InputT, OutputT> Object
-            createInstance(ExtraContextFactory<InputT, OutputT> factory) {
-          return factory.windowingInternals();
-        }
+    /**
+     * Create a type token representing the given parameter. May use the type token associated
+     * with the input and output types of the {@link DoFnWithContext}, depending on the extra
+     * context.
+     */
+    abstract <InputT, OutputT> TypeToken<?> tokenFor(
+        TypeToken<InputT> in, TypeToken<OutputT> out);
 
-        @Override
-        public <InputT, OutputT> TypeToken<?>
-            tokenFor(TypeToken<InputT> in, TypeToken<OutputT> out) {
-          return new TypeToken<WindowingInternals<InputT, OutputT>>() {
-            }
-          .where(new TypeParameter<InputT>() {}, in)
-          .where(new TypeParameter<OutputT>() {}, out);
-        }
-      })
-      .build();
+    private final Class<?> rawType;
+    private final Availability availability;
+    private final transient MethodDescription method;
 
-  /**
-   * @return true if the reflected {@link DoFnWithContext} uses a Single Window.
-   */
-  public abstract boolean usesSingleWindow();
+    private AdditionalParameter(Availability availability, Class<?> rawType, String method) {
+      this.availability = availability;
+      this.rawType = rawType;
+      try {
+        this.method = new MethodDescription.ForLoadedMethod(
+            ExtraContextFactory.class.getMethod(method));
+      } catch (NoSuchMethodException | SecurityException e) {
+        throw new RuntimeException(
+            "Unable to access method " + method + " on " + ExtraContextFactory.class, e);
+      }
+    }
+  }
 
-  /**
-   * Invoke the reflected {@link ProcessElement} method on the given instance.
-   *
-   * @param fn an instance of the {@link DoFnWithContext} to invoke {@link ProcessElement} on.
-   * @param c the {@link org.apache.beam.sdk.transforms.DoFnWithContext.ProcessContext}
-   *     to pass to {@link ProcessElement}.
-   */
-  public abstract <InputT, OutputT> void invokeProcessElement(
-      DoFnWithContext<InputT, OutputT> fn,
-      DoFnWithContext<InputT, OutputT>.ProcessContext c,
-      ExtraContextFactory<InputT, OutputT> extra);
+  private static final Map<Class<?>, AdditionalParameter> EXTRA_CONTEXTS;
+  private static final Map<Class<?>, AdditionalParameter> EXTRA_PROCESS_CONTEXTS;
+
+  static {
+    ImmutableMap.Builder<Class<?>, AdditionalParameter> everywhereBuilder =
+        ImmutableMap.<Class<?>, AdditionalParameter>builder();
+    ImmutableMap.Builder<Class<?>, AdditionalParameter> processElementBuilder =
+        ImmutableMap.<Class<?>, AdditionalParameter>builder();
+
+    for (AdditionalParameter value : AdditionalParameter.values()) {
+      switch (value.availability) {
+        case EVERYWHERE:
+          everywhereBuilder.put(value.rawType, value);
+          break;
+        case PROCESS_ELEMENT_ONLY:
+          processElementBuilder.put(value.rawType, value);
+          break;
+      }
+    }
 
-  /**
-   * Invoke the reflected {@link StartBundle} method on the given instance.
-   *
-   * @param fn an instance of the {@link DoFnWithContext} to invoke {@link StartBundle} on.
-   * @param c the {@link org.apache.beam.sdk.transforms.DoFnWithContext.Context}
-   *     to pass to {@link StartBundle}.
-   */
-  public <InputT, OutputT> void invokeStartBundle(
-     DoFnWithContext<InputT, OutputT> fn,
-     DoFnWithContext<InputT, OutputT>.Context c,
-     ExtraContextFactory<InputT, OutputT> extra) {
-    fn.prepareForProcessing();
+    EXTRA_CONTEXTS = everywhereBuilder.build();
+    EXTRA_PROCESS_CONTEXTS = processElementBuilder
+        // Process Element contexts include everything available everywhere
+        .putAll(EXTRA_CONTEXTS)
+        .build();
   }
 
   /**
-   * Invoke the reflected {@link FinishBundle} method on the given instance.
-   *
-   * @param fn an instance of the {@link DoFnWithContext} to invoke {@link FinishBundle} on.
-   * @param c the {@link org.apache.beam.sdk.transforms.DoFnWithContext.Context}
-   *     to pass to {@link FinishBundle}.
+   * @return true if the reflected {@link DoFnWithContext} uses a Single Window.
    */
-  public abstract <InputT, OutputT> void invokeFinishBundle(
-      DoFnWithContext<InputT, OutputT> fn,
-      DoFnWithContext<InputT, OutputT>.Context c,
-      ExtraContextFactory<InputT, OutputT> extra);
+  public abstract boolean usesSingleWindow();
+
+  /** Create an {@link DoFnInvoker} bound to the given {@link DoFn}. */
+  public abstract <InputT, OutputT> DoFnInvoker<InputT, OutputT> bindInvoker(
+      DoFnWithContext<InputT, OutputT> fn);
 
   private static final Map<Class<?>, DoFnReflector> REFLECTOR_CACHE =
       new LinkedHashMap<Class<?>, DoFnReflector>();
@@ -192,14 +236,15 @@ public abstract class DoFnReflector {
   }
 
   private static Collection<String> describeSupportedTypes(
-      Map<Class<?>, ExtraContextInfo> extraProcessContexts,
+      Map<Class<?>, AdditionalParameter> extraProcessContexts,
       final TypeToken<?> in, final TypeToken<?> out) {
     return FluentIterable
         .from(extraProcessContexts.values())
-        .transform(new Function<ExtraContextInfo, String>() {
+        .transform(new Function<AdditionalParameter, String>() {
+
           @Override
           @Nullable
-          public String apply(@Nullable ExtraContextInfo input) {
+          public String apply(@Nullable AdditionalParameter input) {
             if (input == null) {
               return null;
             } else {
@@ -211,21 +256,22 @@ public abstract class DoFnReflector {
   }
 
   @VisibleForTesting
-  static <InputT, OutputT> ExtraContextInfo[] verifyProcessMethodArguments(Method m) {
+  static <InputT, OutputT> List<AdditionalParameter> verifyProcessMethodArguments(Method m) {
     return verifyMethodArguments(m,
         EXTRA_PROCESS_CONTEXTS,
-        new TypeToken<DoFnWithContext<InputT, OutputT>.ProcessContext>() {
-          },
+        new TypeToken<DoFnWithContext<InputT, OutputT>.ProcessContext>() {},
         new TypeParameter<InputT>() {},
         new TypeParameter<OutputT>() {});
   }
 
   @VisibleForTesting
-  static <InputT, OutputT> ExtraContextInfo[] verifyBundleMethodArguments(Method m) {
+  static <InputT, OutputT> List<AdditionalParameter> verifyBundleMethodArguments(Method m) {
+    if (m == null) {
+      return null;
+    }
     return verifyMethodArguments(m,
         EXTRA_CONTEXTS,
-        new TypeToken<DoFnWithContext<InputT, OutputT>.Context>() {
-          },
+        new TypeToken<DoFnWithContext<InputT, OutputT>.Context>() {},
         new TypeParameter<InputT>() {},
         new TypeParameter<OutputT>() {});
   }
@@ -245,15 +291,18 @@ public abstract class DoFnReflector {
    * </ol>
    *
    * @param m the method to verify
-   * @param contexts mapping from raw classes to the {@link ExtraContextInfo} used
+   * @param contexts mapping from raw classes to the {@link AdditionalParameter} used
    *     to create new instances.
    * @param firstContextArg the expected type of the first context argument
    * @param iParam TypeParameter representing the input type
    * @param oParam TypeParameter representing the output type
    */
-  @VisibleForTesting static <InputT, OutputT> ExtraContextInfo[] verifyMethodArguments(Method m,
-      Map<Class<?>, ExtraContextInfo> contexts,
-      TypeToken<?> firstContextArg, TypeParameter<InputT> iParam, TypeParameter<OutputT> oParam) {
+  @VisibleForTesting static <InputT, OutputT> List<AdditionalParameter> verifyMethodArguments(
+      Method m,
+      Map<Class<?>, AdditionalParameter> contexts,
+      TypeToken<?> firstContextArg,
+      TypeParameter<InputT> iParam,
+      TypeParameter<OutputT> oParam) {
 
     if (!void.class.equals(m.getReturnType())) {
       throw new IllegalStateException(String.format(
@@ -276,7 +325,7 @@ public abstract class DoFnReflector {
           "%s must take a %s as its first argument",
           format(m), firstContextArg.getRawType().getSimpleName()));
     }
-    ExtraContextInfo[] contextInfos = new ExtraContextInfo[params.length - 1];
+    AdditionalParameter[] contextInfos = new AdditionalParameter[params.length - 1];
 
     // Fill in the generics in the allExtraContextArgs interface from the types in the
     // Context or ProcessContext DoFn.
@@ -293,7 +342,7 @@ public abstract class DoFnReflector {
     for (int i = 1; i < params.length; i++) {
       TypeToken<?> param = TypeToken.of(params[i]);
 
-      ExtraContextInfo info = contexts.get(param.getRawType());
+      AdditionalParameter info = contexts.get(param.getRawType());
       if (info == null) {
         throw new IllegalStateException(String.format(
             "%s is not a valid context parameter for method %s. Should be one of %s",
@@ -312,7 +361,24 @@ public abstract class DoFnReflector {
       // Register the (now validated) context info
       contextInfos[i - 1] = info;
     }
-    return contextInfos;
+    return ImmutableList.copyOf(contextInfos);
+  }
+
+  /** Interface for invoking the {@code DoFn} processing methods. */
+  public interface DoFnInvoker<InputT, OutputT>  {
+    /** Invoke {@link DoFn#startBundle} on the bound {@code DoFn}. */
+    void invokeStartBundle(
+        DoFnWithContext<InputT, OutputT>.Context c,
+        ExtraContextFactory<InputT, OutputT> extra);
+    /** Invoke {@link DoFn#finishBundle} on the bound {@code DoFn}. */
+    void invokeFinishBundle(
+        DoFnWithContext<InputT, OutputT>.Context c,
+        ExtraContextFactory<InputT, OutputT> extra);
+
+    /** Invoke {@link DoFn#processElement} on the bound {@code DoFn}. */
+    public void invokeProcessElement(
+        DoFnWithContext<InputT, OutputT>.ProcessContext c,
+        ExtraContextFactory<InputT, OutputT> extra);
   }
 
   /**
@@ -320,27 +386,27 @@ public abstract class DoFnReflector {
    */
   private static class GenericDoFnReflector extends DoFnReflector {
 
-    private Method startBundle;
-    private Method processElement;
-    private Method finishBundle;
-    private ExtraContextInfo[] processElementArgs;
-    private ExtraContextInfo[] startBundleArgs;
-    private ExtraContextInfo[] finishBundleArgs;
+    private final Method startBundle;
+    private final Method processElement;
+    private final Method finishBundle;
+    private final List<AdditionalParameter> processElementArgs;
+    private final List<AdditionalParameter> startBundleArgs;
+    private final List<AdditionalParameter> finishBundleArgs;
+    private final Constructor<?> constructor;
 
-    private GenericDoFnReflector(Class<?> fn) {
+    private GenericDoFnReflector(
+        @SuppressWarnings("rawtypes") Class<? extends DoFnWithContext> fn) {
       // Locate the annotated methods
       this.processElement = findAnnotatedMethod(ProcessElement.class, fn, true);
       this.startBundle = findAnnotatedMethod(StartBundle.class, fn, false);
       this.finishBundle = findAnnotatedMethod(FinishBundle.class, fn, false);
 
       // Verify that their method arguments satisfy our conditions.
-      processElementArgs = verifyProcessMethodArguments(processElement);
-      if (startBundle != null) {
-        startBundleArgs = verifyBundleMethodArguments(startBundle);
-      }
-      if (finishBundle != null) {
-        finishBundleArgs = verifyBundleMethodArguments(finishBundle);
-      }
+      this.processElementArgs = verifyProcessMethodArguments(processElement);
+      this.startBundleArgs = verifyBundleMethodArguments(startBundle);
+      this.finishBundleArgs = verifyBundleMethodArguments(finishBundle);
+
+      this.constructor = createInvokerConstructor(fn);
     }
 
     private static Collection<Method> declaredMethodsWithAnnotation(
@@ -411,74 +477,74 @@ public abstract class DoFnReflector {
         throw new IllegalStateException(format(first) + " must not be static");
       }
 
-      first.setAccessible(true);
       return first;
     }
 
     @Override
     public boolean usesSingleWindow() {
-      return usesContext(BoundedWindow.class);
+      return usesContext(AdditionalParameter.WINDOW_OF_ELEMENT);
     }
 
-    private boolean usesContext(Class<?> context) {
-      for (Class<?> clazz : processElement.getParameterTypes()) {
-        if (clazz.equals(context)) {
-          return true;
-        }
-      }
-      return false;
+    private boolean usesContext(AdditionalParameter param) {
+      return processElementArgs.contains(param)
+          || (startBundleArgs != null && startBundleArgs.contains(param))
+          || (finishBundleArgs != null && finishBundleArgs.contains(param));
     }
 
-    @Override
-    public <InputT, OutputT> void invokeProcessElement(
-        DoFnWithContext<InputT, OutputT> fn,
-        DoFnWithContext<InputT, OutputT>.ProcessContext c,
-        ExtraContextFactory<InputT, OutputT> extra) {
-      invoke(processElement, fn, c, extra, processElementArgs);
-    }
-
-    @Override
-    public <InputT, OutputT> void invokeStartBundle(
-        DoFnWithContext<InputT, OutputT> fn,
-        DoFnWithContext<InputT, OutputT>.Context c,
-        ExtraContextFactory<InputT, OutputT> extra) {
-      super.invokeStartBundle(fn, c, extra);
-      if (startBundle != null) {
-        invoke(startBundle, fn, c, extra, startBundleArgs);
+    /**
+     * Use ByteBuddy to generate the code for a {@link DoFnInvoker} that invokes the given
+     * {@link DoFnWithContext}.
+     * @param clazz
+     * @return
+     */
+    private Constructor<? extends DoFnInvoker<?, ?>> createInvokerConstructor(
+        @SuppressWarnings("rawtypes") Class<? extends DoFnWithContext> clazz) {
+      DynamicType.Builder<?> builder = new ByteBuddy()
+          .subclass(DoFnInvoker.class, ConstructorStrategy.Default.NO_CONSTRUCTORS)
+          .defineField(FN_DELEGATE_FIELD_NAME, clazz, Visibility.PRIVATE, FieldManifestation.FINAL)
+          // Define a constructor to populate fields appropriately.
+          .defineConstructor(Visibility.PUBLIC)
+          .withParameter(clazz)
+          .intercept(new InvokerConstructor())
+          // Implement the three methods by calling into the appropriate functions on the fn.
+          .method(ElementMatchers.named("invokeProcessElement"))
+          .intercept(InvokerDelegation.create(
+              processElement, BeforeDelegation.NOOP, processElementArgs))
+          .method(ElementMatchers.named("invokeStartBundle"))
+          .intercept(InvokerDelegation.create(
+              startBundle, BeforeDelegation.INVOKE_PREPARE_FOR_PROCESSING, startBundleArgs))
+          .method(ElementMatchers.named("invokeFinishBundle"))
+          .intercept(InvokerDelegation.create(
+              finishBundle, BeforeDelegation.NOOP, finishBundleArgs));
+
+      @SuppressWarnings("unchecked")
+      Class<? extends DoFnInvoker<?, ?>> dynamicClass = (Class<? extends DoFnInvoker<?, ?>>) builder
+          .make()
+          .load(getClass().getClassLoader(), ClassLoadingStrategy.Default.INJECTION)
+          .getLoaded();
+      try {
+        return dynamicClass.getConstructor(clazz);
+      } catch (IllegalArgumentException
+          | NoSuchMethodException
+          | SecurityException e) {
+        throw new RuntimeException(e);
       }
     }
 
     @Override
-    public <InputT, OutputT> void invokeFinishBundle(
-        DoFnWithContext<InputT, OutputT> fn,
-        DoFnWithContext<InputT, OutputT>.Context c,
-        ExtraContextFactory<InputT, OutputT> extra) {
-      if (finishBundle != null) {
-        invoke(finishBundle, fn, c, extra, finishBundleArgs);
-      }
-    }
-
-    private <InputT, OutputT> void invoke(Method m,
-        DoFnWithContext<InputT, OutputT> on,
-        DoFnWithContext<InputT, OutputT>.Context contextArg,
-        ExtraContextFactory<InputT, OutputT> extraArgFactory,
-        ExtraContextInfo[] extraArgs) {
-
-      Class<?>[] parameterTypes = m.getParameterTypes();
-      Object[] args = new Object[parameterTypes.length];
-      args[0] = contextArg;
-      for (int i = 1; i < args.length; i++) {
-        args[i] = extraArgs[i - 1].createInstance(extraArgFactory);
-      }
-
+    public <InputT, OutputT> DoFnInvoker<InputT, OutputT> bindInvoker(
+        DoFnWithContext<InputT, OutputT> fn) {
       try {
-        m.invoke(on, args);
-      } catch (InvocationTargetException e) {
-        // Exception in user code.
-        throw UserCodeException.wrap(e.getCause());
-      } catch (IllegalAccessException | IllegalArgumentException e) {
-        // Exception in our code.
-        throw new RuntimeException(e);
+        @SuppressWarnings("unchecked")
+        DoFnInvoker<InputT, OutputT> invoker =
+            (DoFnInvoker<InputT, OutputT>) constructor.newInstance(fn);
+        return invoker;
+      } catch (InstantiationException
+          | IllegalAccessException
+          | IllegalArgumentException
+          | InvocationTargetException
+          | SecurityException e) {
+        throw new RuntimeException("Unable to bind invoker for " + fn.getClass(), e);
       }
     }
   }
@@ -615,32 +681,31 @@ public abstract class DoFnReflector {
 
   private static class SimpleDoFnAdapter<InputT, OutputT> extends DoFn<InputT, OutputT> {
 
-    private transient DoFnReflector reflector;
-    private DoFnWithContext<InputT, OutputT> fn;
+    private final DoFnWithContext<InputT, OutputT> fn;
+    private transient DoFnInvoker<InputT, OutputT> invoker;
 
     private SimpleDoFnAdapter(DoFnReflector reflector, DoFnWithContext<InputT, OutputT> fn) {
       super(fn.aggregators);
-      this.reflector = reflector;
       this.fn = fn;
+      this.invoker = reflector.bindInvoker(fn);
     }
 
     @Override
     public void startBundle(DoFn<InputT, OutputT>.Context c) throws Exception {
       ContextAdapter<InputT, OutputT> adapter = new ContextAdapter<>(fn, c);
-      reflector.invokeStartBundle(fn, (DoFnWithContext<InputT, OutputT>.Context) adapter, adapter);
+      invoker.invokeStartBundle(adapter, adapter);
     }
 
     @Override
     public void finishBundle(DoFn<InputT, OutputT>.Context c) throws Exception {
       ContextAdapter<InputT, OutputT> adapter = new ContextAdapter<>(fn, c);
-      reflector.invokeFinishBundle(fn, (DoFnWithContext<InputT, OutputT>.Context) adapter, adapter);
+      invoker.invokeFinishBundle(adapter, adapter);
     }
 
     @Override
     public void processElement(DoFn<InputT, OutputT>.ProcessContext c) throws Exception {
       ProcessContextAdapter<InputT, OutputT> adapter = new ProcessContextAdapter<>(fn, c);
-      reflector.invokeProcessElement(
-          fn, (DoFnWithContext<InputT, OutputT>.ProcessContext) adapter, adapter);
+      invoker.invokeProcessElement(adapter, adapter);
     }
 
     @Override
@@ -661,7 +726,7 @@ public abstract class DoFnReflector {
     private void readObject(java.io.ObjectInputStream in)
         throws IOException, ClassNotFoundException {
       in.defaultReadObject();
-      reflector = DoFnReflector.of(fn.getClass());
+      invoker = DoFnReflector.of(fn.getClass()).bindInvoker(fn);
     }
   }
 
@@ -672,4 +737,316 @@ public abstract class DoFnReflector {
       super(reflector, fn);
     }
   }
+
+  private static enum BeforeDelegation {
+    NOOP {
+      @Override
+      StackManipulation manipulation(
+          TypeDescription delegateType, MethodDescription instrumentedMethod, boolean finalStep) {
+        Preconditions.checkArgument(!finalStep,
+            "Shouldn't use NOOP delegation if there is nothing to do afterwards.");
+        return StackManipulation.Trivial.INSTANCE;
+      }
+    },
+    INVOKE_PREPARE_FOR_PROCESSING {
+      private final Assigner assigner = Assigner.DEFAULT;
+
+      @Override
+      StackManipulation manipulation(
+          TypeDescription delegateType, MethodDescription instrumentedMethod, boolean finalStep) {
+        MethodDescription prepareMethod;
+        try {
+          prepareMethod = new MethodLocator.ForExplicitMethod(
+              new MethodDescription.ForLoadedMethod(
+                  DoFnWithContext.class.getDeclaredMethod("prepareForProcessing")))
+          .resolve(instrumentedMethod);
+        } catch (NoSuchMethodException | SecurityException e) {
+          throw new RuntimeException("Unable to locate prepareForProcessing method", e);
+        }
+
+        if (finalStep) {
+          return new StackManipulation.Compound(
+              // Invoke the prepare method (consumes the delegate on top of the stack)
+              MethodInvoker.Simple.INSTANCE.invoke(prepareMethod),
+              // Return from the instrumented method; there is no target method to call.
+              TerminationHandler.Returning.INSTANCE.resolve(
+                  assigner, instrumentedMethod, prepareMethod));
+        } else {
+          return new StackManipulation.Compound(
+              // Duplicate the delegation target so that it remains after we invoke prepare
+              Duplication.duplicate(delegateType),
+              // Invoke the prepare method
+              MethodInvoker.Simple.INSTANCE.invoke(prepareMethod),
+              // Drop the return value from prepareForProcessing
+              TerminationHandler.Dropping.INSTANCE.resolve(
+                  assigner, instrumentedMethod, prepareMethod));
+        }
+      }
+    };
+
+    /**
+     * Stack manipulation to perform prior to the delegate call.
+     *
+     * <ul>
+     * <li>Precondition: Stack has the delegate target on top of the stack
+     * <li>Postcondition: If finalStep is true, then we've returned from the method. Otherwise, the
+     * stack still has the delegate target on top of the stack.
+     * </ul>
+     *
+     * @param delegateType The type of the delegate target, in case it needs to be duplicated.
+     * @param instrumentedMethod The method being instrumented. Necessary for resolving types and
+     *     other information.
+     * @param finalStep If true, there is no target method to call, so this step must also return
+     *     from the instrumented method. {@link #NOOP} rejects {@code finalStep == true}.
+     */
+    abstract StackManipulation manipulation(
+        TypeDescription delegateType, MethodDescription instrumentedMethod, boolean finalStep);
+  }
+
+  /**
+   * A byte-buddy {@link Implementation} that delegates a call that receives
+   * {@link AdditionalParameter} to the given {@link DoFnWithContext} method.
+   */
+  private static final class InvokerDelegation implements Implementation {
+    @Nullable
+    private final Method target;
+    private final BeforeDelegation before;
+    private final List<AdditionalParameter> args;
+    private final Assigner assigner = Assigner.DEFAULT;
+    private FieldDescription field;
+
+    /**
+     * Create the {@link InvokerDelegation} for the specified method.
+     *
+     * @param target the method to delegate to, or {@code null} if there is no target to call
+     * @param before the {@link BeforeDelegation} step to emit before invoking {@code target}
+     * @param args the {@link AdditionalParameter} to be passed to the {@code target}
+     */
+    private InvokerDelegation(
+        @Nullable Method target,
+        BeforeDelegation before,
+        List<AdditionalParameter> args) {
+      this.target = target;
+      this.before = before;
+      this.args = args;
+    }
+
+    /**
+     * Generate the {@link Implementation} of one of the life-cycle methods of a
+     * {@link DoFnWithContext}.
+     */
+    private static Implementation create(
+        @Nullable final Method target, BeforeDelegation before, List<AdditionalParameter> args) {
+      if (target == null && before == BeforeDelegation.NOOP) {
+        // There is no target to call and nothing needs to happen before. Just produce a stub.
+        return StubMethod.INSTANCE;
+      } else {
+        // We need to generate a non-empty method implementation.
+        return new InvokerDelegation(target, before, args);
+      }
+    }
+
+    @Override
+    public InstrumentedType prepare(InstrumentedType instrumentedType) {
+      // Remember the field description of the instrumented type.
+      field = instrumentedType.getDeclaredFields()
+          .filter(ElementMatchers.named(FN_DELEGATE_FIELD_NAME)).getOnly();
+
+      // Delegating the method call doesn't require any changes to the instrumented type.
+      return instrumentedType;
+    }
+
+    /**
+     * Stack manipulation to push the {@link DoFnWithContext} reference stored in the
+     * delegate field of the invoker on to the top of the stack.
+     *
+     * <p>This implementation is derived from the code for
+     * {@code MethodCall.invoke(m).onInstanceField(clazz, delegateField)} with two key differences.
+     * First, it doesn't add a synthetic field each time, which is critical to avoid duplicate field
+     * definitions. Second, it uses the {@link AdditionalParameter} to populate the arguments to the
+     * method.
+     */
+    private StackManipulation pushDelegateField() {
+      return new StackManipulation.Compound(
+          // Push "this" reference to the stack
+          MethodVariableAccess.REFERENCE.loadOffset(0),
+          // Access the delegate field of the invoker
+          FieldAccess.forField(field).getter());
+    }
+
+    private StackManipulation pushArgument(
+        AdditionalParameter arg, MethodDescription instrumentedMethod) {
+      MethodDescription transform = arg.method;
+
+      return new StackManipulation.Compound(
+          // Push the ExtraContextFactory which must have been argument 2 of the instrumented method
+          MethodVariableAccess.REFERENCE.loadOffset(2),
+          // Invoke the appropriate method to produce the context argument
+          MethodInvocation.invoke(transform));
+    }
+
+    private StackManipulation invokeTargetMethod(MethodDescription instrumentedMethod) {
+      MethodDescription targetMethod = new MethodLocator.ForExplicitMethod(
+          new MethodDescription.ForLoadedMethod(target)).resolve(instrumentedMethod);
+      ParameterList<?> params = targetMethod.getParameters();
+
+      // Instructions to setup the parameters for the call
+      ArrayList<StackManipulation> parameters = new ArrayList<>(args.size() + 1);
+      // 1. The first argument in the delegate method must be the context. This corresponds to
+      //    the first argument in the instrumented method, so copy that.
+      parameters.add(MethodVariableAccess.of(
+          params.get(0).getType().getSuperClass()).loadOffset(1));
+      // 2. For each of the extra arguments push the appropriate value.
+      for (AdditionalParameter arg : args) {
+        parameters.add(pushArgument(arg, instrumentedMethod));
+      }
+
+      return new StackManipulation.Compound(
+          // Push the parameters
+          new StackManipulation.Compound(parameters),
+          // Invoke the target method
+          wrapWithUserCodeException(MethodInvoker.Simple.INSTANCE.invoke(targetMethod)),
+          // Return from the instrumented method
+          TerminationHandler.Returning.INSTANCE.resolve(
+              assigner, instrumentedMethod, targetMethod));
+    }
+
+    /**
+     * Wrap a given stack manipulation in a try catch block. Any exceptions thrown within the
+     * try are wrapped with a {@link UserCodeException}.
+     */
+    private StackManipulation wrapWithUserCodeException(
+        final StackManipulation tryBody) {
+      final MethodDescription createUserCodeException;
+      try {
+        createUserCodeException = new MethodDescription.ForLoadedMethod(
+                UserCodeException.class.getDeclaredMethod("wrap", Throwable.class));
+      } catch (NoSuchMethodException | SecurityException e) {
+        throw new RuntimeException("Unable to find UserCodeException.wrap", e);
+      }
+
+      return new StackManipulation() {
+        @Override
+        public boolean isValid() {
+          return tryBody.isValid();
+        }
+
+        @Override
+        public Size apply(MethodVisitor mv, Context implementationContext) {
+          Label tryBlockStart = new Label();
+          Label tryBlockEnd = new Label();
+          Label catchBlockStart = new Label();
+          Label catchBlockEnd = new Label();
+
+          String throwableName =
+              new TypeDescription.ForLoadedType(Throwable.class).getInternalName();
+          mv.visitTryCatchBlock(tryBlockStart, tryBlockEnd, catchBlockStart, throwableName);
+
+          // The try block attempts to perform the expected operations, then jumps to success
+          mv.visitLabel(tryBlockStart);
+          Size trySize = tryBody.apply(mv, implementationContext);
+          mv.visitJumpInsn(Opcodes.GOTO, catchBlockEnd);
+          mv.visitLabel(tryBlockEnd);
+
+          // The handler wraps the exception, and then throws.
+          mv.visitLabel(catchBlockStart);
+          // Add the exception to the frame
+          mv.visitFrame(Opcodes.F_SAME1,
+              // No local variables
+              0, new Object[] {},
+              // 1 stack element (the throwable)
+              1, new Object[] { throwableName });
+
+          Size catchSize = new StackManipulation.Compound(
+              MethodInvocation.invoke(createUserCodeException),
+              Throw.INSTANCE)
+              .apply(mv, implementationContext);
+
+          mv.visitLabel(catchBlockEnd);
+          // The frame contents after the try/catch block is the same
+          // as it was before.
+          mv.visitFrame(Opcodes.F_SAME,
+              // No local variables
+              0, new Object[] {},
+              // No new stack variables
+              0, new Object[] {});
+
+          return new Size(
+              trySize.getSizeImpact(),
+              Math.max(trySize.getMaximalSize(), catchSize.getMaximalSize()));
+        }
+      };
+    }
+
+    @Override
+    public ByteCodeAppender appender(final Target implementationTarget) {
+      return new ByteCodeAppender() {
+        @Override
+        public Size apply(
+            MethodVisitor methodVisitor,
+            Context implementationContext,
+            MethodDescription instrumentedMethod) {
+          StackManipulation.Size size = new StackManipulation.Compound(
+              // Put the target on top of the stack
+              pushDelegateField(),
+              // Do any necessary pre-delegation work
+              before.manipulation(field.getType().asErasure(), instrumentedMethod, target == null),
+              // Invoke the target method, if there is one. If not, `before` ran as the final step
+              // and has already emitted the appropriate return instructions.
+              target != null
+                  ? invokeTargetMethod(instrumentedMethod)
+                  : StackManipulation.Trivial.INSTANCE)
+              .apply(methodVisitor, implementationContext);
+          return new Size(size.getMaximalSize(), instrumentedMethod.getStackSize());
+        }
+      };
+    }
+  }
+
+  /**
+   * A constructor {@link Implementation} for a {@link DoFnInvoker} class. Produces the byte code
+   * for a constructor that takes a single argument and assigns it to the delegate field, after
+   * invoking the no-argument {@code Object} super constructor.
+   */
+  private static final class InvokerConstructor implements Implementation {
+    @Override
+    public InstrumentedType prepare(InstrumentedType instrumentedType) {
+      return instrumentedType;
+    }
+
+    @Override
+    public ByteCodeAppender appender(final Target implementationTarget) {
+      return new ByteCodeAppender() {
+        @Override
+        public Size apply(
+            MethodVisitor methodVisitor,
+            Context implementationContext,
+            MethodDescription instrumentedMethod) {
+          StackManipulation.Size size = new StackManipulation.Compound(
+              // Load the this reference
+              MethodVariableAccess.REFERENCE.loadOffset(0),
+              // Invoke the super constructor (default constructor of Object)
+              MethodInvocation
+                  .invoke(new TypeDescription.ForLoadedType(Object.class)
+                    .getDeclaredMethods()
+                    .filter(ElementMatchers.isConstructor()
+                      .and(ElementMatchers.takesArguments(0)))
+                    .getOnly()),
+              // Load the this reference
+              MethodVariableAccess.REFERENCE.loadOffset(0),
+              // Load the delegate argument
+              MethodVariableAccess.REFERENCE.loadOffset(1),
+              // Assign the delegate argument to the delegate field
+              FieldAccess.forField(implementationTarget.getInstrumentedType()
+                  .getDeclaredFields()
+                  .filter(ElementMatchers.named(FN_DELEGATE_FIELD_NAME))
+                  .getOnly()).putter(),
+              // Return void.
+              MethodReturn.VOID
+            ).apply(methodVisitor, implementationContext);
+            return new Size(size.getMaximalSize(), instrumentedMethod.getStackSize());
+        }
+      };
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2b47919c/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnReflectorTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnReflectorTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnReflectorTest.java
index 7399322..cf9f8e8 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnReflectorTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnReflectorTest.java
@@ -26,6 +26,7 @@ import org.apache.beam.sdk.transforms.DoFnWithContext.ExtraContextFactory;
 import org.apache.beam.sdk.transforms.DoFnWithContext.ProcessContext;
 import org.apache.beam.sdk.transforms.DoFnWithContext.ProcessElement;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.UserCodeException;
 import org.apache.beam.sdk.util.WindowingInternals;
 
 import org.junit.Before;
@@ -45,9 +46,16 @@ import java.lang.reflect.Method;
 @RunWith(JUnit4.class)
 public class DoFnReflectorTest {
 
-  private boolean wasProcessElementInvoked = false;
-  private boolean wasStartBundleInvoked = false;
-  private boolean wasFinishBundleInvoked = false;
+  private static class Invocations {
+    private boolean wasProcessElementInvoked = false;
+    private boolean wasStartBundleInvoked = false;
+    private boolean wasFinishBundleInvoked = false;
+    private final String name;
+
+    public Invocations(String name) {
+      this.name = name;
+    }
+  }
 
   private DoFnWithContext<String, String> fn;
 
@@ -84,39 +92,78 @@ public class DoFnReflectorTest {
     return DoFnReflector.of(fn.getClass());
   }
 
-  private void checkInvokeProcessElementWorks(DoFnReflector r) throws Exception {
-    assertFalse(wasProcessElementInvoked);
-    r.invokeProcessElement(fn, mockContext, extraContextFactory);
-    assertTrue(wasProcessElementInvoked);
+  private void checkInvokeProcessElementWorks(
+      DoFnReflector r, Invocations... invocations) throws Exception {
+    assertTrue("Need at least one invocation to check", invocations.length >= 1);
+    for (Invocations invocation : invocations) {
+      assertFalse("Should not yet have called processElement on " + invocation.name,
+          invocation.wasProcessElementInvoked);
+    }
+    r.bindInvoker(fn).invokeProcessElement(mockContext, extraContextFactory);
+    for (Invocations invocation : invocations) {
+      assertTrue("Should have called processElement on " + invocation.name,
+          invocation.wasProcessElementInvoked);
+    }
   }
 
-  private void checkInvokeStartBundleWorks(DoFnReflector r) throws Exception {
-    assertFalse(wasStartBundleInvoked);
-    r.invokeStartBundle(fn, mockContext, extraContextFactory);
-    assertTrue(wasStartBundleInvoked);
+  private void checkInvokeStartBundleWorks(
+      DoFnReflector r, Invocations... invocations) throws Exception {
+    assertTrue("Need at least one invocation to check", invocations.length >= 1);
+    for (Invocations invocation : invocations) {
+      assertFalse("Should not yet have called startBundle on " + invocation.name,
+          invocation.wasStartBundleInvoked);
+    }
+    r.bindInvoker(fn).invokeStartBundle(mockContext, extraContextFactory);
+    for (Invocations invocation : invocations) {
+      assertTrue("Should have called startBundle on " + invocation.name,
+          invocation.wasStartBundleInvoked);
+    }
   }
 
-  private void checkInvokeFinishBundleWorks(DoFnReflector r) throws Exception {
-    assertFalse(wasFinishBundleInvoked);
-    r.invokeFinishBundle(fn, mockContext, extraContextFactory);
-    assertTrue(wasFinishBundleInvoked);
+  private void checkInvokeFinishBundleWorks(
+      DoFnReflector r, Invocations... invocations) throws Exception {
+    assertTrue("Need at least one invocation to check", invocations.length >= 1);
+    for (Invocations invocation : invocations) {
+      assertFalse("Should not yet have called finishBundle on " + invocation.name,
+          invocation.wasFinishBundleInvoked);
+    }
+    r.bindInvoker(fn).invokeFinishBundle(mockContext, extraContextFactory);
+    for (Invocations invocation : invocations) {
+      assertTrue("Should have called finishBundle on " + invocation.name,
+          invocation.wasFinishBundleInvoked);
+    }
   }
 
   @Test
   public void testDoFnWithNoExtraContext() throws Exception {
+    final Invocations invocations = new Invocations("AnonymousClass");
     DoFnReflector reflector = underTest(new DoFnWithContext<String, String>() {
 
       @ProcessElement
       public void processElement(ProcessContext c)
           throws Exception {
-        wasProcessElementInvoked = true;
+        invocations.wasProcessElementInvoked = true;
         assertSame(c, mockContext);
       }
     });
 
     assertFalse(reflector.usesSingleWindow());
 
-    checkInvokeProcessElementWorks(reflector);
+    checkInvokeProcessElementWorks(reflector, invocations);
+  }
+
+  @Test
+  public void testDoFnInvokersReused() throws Exception {
+    // Ensures that we don't create a new Invoker class for every instance of the DoFn.
+    IdentityParent fn1 = new IdentityParent();
+    IdentityParent fn2 = new IdentityParent();
+    DoFnReflector reflector1 = underTest(fn1);
+    DoFnReflector reflector2 = underTest(fn2);
+    assertSame("DoFnReflector instances should be cached and reused for identical types",
+        reflector1, reflector2);
+    assertSame("Invoker classes should only be generated once for each type",
+        reflector1.bindInvoker(fn1).getClass(),
+        reflector2.bindInvoker(fn2).getClass());
   }
 
   interface InterfaceWithProcessElement {
@@ -130,45 +177,71 @@ public class DoFnReflectorTest {
       extends DoFnWithContext<String, String>
       implements LayersOfInterfaces {
 
+    private Invocations invocations = new Invocations("Named Class");
+
     @Override
     public void processElement(DoFnWithContext<String, String>.ProcessContext c) {
-      wasProcessElementInvoked = true;
+      invocations.wasProcessElementInvoked = true;
       assertSame(c, mockContext);
     }
   }
 
   @Test
   public void testDoFnWithProcessElementInterface() throws Exception {
-    DoFnReflector reflector = underTest(new IdentityUsingInterfaceWithProcessElement());
+    IdentityUsingInterfaceWithProcessElement fn = new IdentityUsingInterfaceWithProcessElement();
+    DoFnReflector reflector = underTest(fn);
     assertFalse(reflector.usesSingleWindow());
-    checkInvokeProcessElementWorks(reflector);
+    checkInvokeProcessElementWorks(reflector, fn.invocations);
   }
 
   private class IdentityParent extends DoFnWithContext<String, String> {
+    protected Invocations parentInvocations = new Invocations("IdentityParent");
+
     @ProcessElement
     public void process(ProcessContext c) {
-      wasProcessElementInvoked = true;
+      parentInvocations.wasProcessElementInvoked = true;
       assertSame(c, mockContext);
     }
   }
 
-  private class IdentityChild extends IdentityParent {}
+  private class IdentityChildWithoutOverride extends IdentityParent {
+  }
+
+  private class IdentityChildWithOverride extends IdentityParent {
+    protected Invocations childInvocations = new Invocations("IdentityChildWithOverride");
+
+    @Override
+    public void process(DoFnWithContext<String, String>.ProcessContext c) {
+      super.process(c);
+      childInvocations.wasProcessElementInvoked = true;
+    }
+  }
 
   @Test
   public void testDoFnWithMethodInSuperclass() throws Exception {
-    DoFnReflector reflector = underTest(new IdentityChild());
+    IdentityChildWithoutOverride fn = new IdentityChildWithoutOverride();
+    DoFnReflector reflector = underTest(fn);
     assertFalse(reflector.usesSingleWindow());
-    checkInvokeProcessElementWorks(reflector);
+    checkInvokeProcessElementWorks(reflector, fn.parentInvocations);
+  }
+
+  @Test
+  public void testDoFnWithMethodInSubclass() throws Exception {
+    IdentityChildWithOverride fn = new IdentityChildWithOverride();
+    DoFnReflector reflector = underTest(fn);
+    assertFalse(reflector.usesSingleWindow());
+    checkInvokeProcessElementWorks(reflector, fn.parentInvocations, fn.childInvocations);
   }
 
   @Test
   public void testDoFnWithWindow() throws Exception {
+    final Invocations invocations = new Invocations("AnonymousClass");
     DoFnReflector reflector = underTest(new DoFnWithContext<String, String>() {
 
       @ProcessElement
       public void processElement(ProcessContext c, BoundedWindow w)
           throws Exception {
-        wasProcessElementInvoked = true;
+        invocations.wasProcessElementInvoked = true;
         assertSame(c, mockContext);
         assertSame(w, mockWindow);
       }
@@ -176,17 +249,18 @@ public class DoFnReflectorTest {
 
     assertTrue(reflector.usesSingleWindow());
 
-    checkInvokeProcessElementWorks(reflector);
+    checkInvokeProcessElementWorks(reflector, invocations);
   }
 
   @Test
   public void testDoFnWithWindowingInternals() throws Exception {
+    final Invocations invocations = new Invocations("AnonymousClass");
     DoFnReflector reflector = underTest(new DoFnWithContext<String, String>() {
 
       @ProcessElement
       public void processElement(ProcessContext c, WindowingInternals<String, String> w)
           throws Exception {
-        wasProcessElementInvoked = true;
+        invocations.wasProcessElementInvoked = true;
         assertSame(c, mockContext);
         assertSame(w, mockWindowingInternals);
       }
@@ -194,30 +268,31 @@ public class DoFnReflectorTest {
 
     assertFalse(reflector.usesSingleWindow());
 
-    checkInvokeProcessElementWorks(reflector);
+    checkInvokeProcessElementWorks(reflector, invocations);
   }
 
   @Test
   public void testDoFnWithStartBundle() throws Exception {
+    final Invocations invocations = new Invocations("AnonymousClass");
     DoFnReflector reflector = underTest(new DoFnWithContext<String, String>() {
       @ProcessElement
       public void processElement(@SuppressWarnings("unused") ProcessContext c) {}
 
       @StartBundle
       public void startBundle(Context c) {
-        wasStartBundleInvoked = true;
+        invocations.wasStartBundleInvoked = true;
         assertSame(c, mockContext);
       }
 
       @FinishBundle
       public void finishBundle(Context c) {
-        wasFinishBundleInvoked = true;
+        invocations.wasFinishBundleInvoked = true;
         assertSame(c, mockContext);
       }
     });
 
-    checkInvokeStartBundleWorks(reflector);
-    checkInvokeFinishBundleWorks(reflector);
+    checkInvokeStartBundleWorks(reflector, invocations);
+    checkInvokeFinishBundleWorks(reflector, invocations);
   }
 
   @Test
@@ -321,7 +396,7 @@ public class DoFnReflectorTest {
     });
   }
 
-  @SuppressWarnings({"unused", "rawtypes"})
+  @SuppressWarnings({"unused"})
   private void missingProcessContext() {}
 
   @Test
@@ -334,7 +409,7 @@ public class DoFnReflectorTest {
         getClass().getDeclaredMethod("missingProcessContext"));
   }
 
-  @SuppressWarnings({"unused", "rawtypes"})
+  @SuppressWarnings({"unused"})
   private void badProcessContext(String s) {}
 
   @Test
@@ -347,7 +422,7 @@ public class DoFnReflectorTest {
         getClass().getDeclaredMethod("badProcessContext", String.class));
   }
 
-  @SuppressWarnings({"unused", "rawtypes"})
+  @SuppressWarnings({"unused"})
   private void badExtraContext(DoFnWithContext<Integer, String>.Context c, int n) {}
 
   @Test
@@ -361,7 +436,7 @@ public class DoFnReflectorTest {
         getClass().getDeclaredMethod("badExtraContext", Context.class, int.class));
   }
 
-  @SuppressWarnings({"unused", "rawtypes"})
+  @SuppressWarnings({"unused"})
   private void badExtraProcessContext(
       DoFnWithContext<Integer, String>.ProcessContext c, Integer n) {}
 
@@ -491,4 +566,54 @@ public class DoFnReflectorTest {
 
     DoFnReflector.verifyProcessMethodArguments(method);
   }
+
+  @Test
+  public void testProcessElementException() throws Exception {
+    DoFnWithContext<Integer, Integer> fn = new DoFnWithContext<Integer, Integer>() {
+      @ProcessElement
+      public void processElement(@SuppressWarnings("unused") ProcessContext c) {
+        throw new IllegalArgumentException("bogus");
+      }
+    };
+
+    thrown.expect(UserCodeException.class);
+    thrown.expectMessage("bogus");
+    DoFnReflector.of(fn.getClass()).bindInvoker(fn).invokeProcessElement(null, null);
+  }
+
+  @Test
+  public void testStartBundleException() throws Exception {
+    DoFnWithContext<Integer, Integer> fn = new DoFnWithContext<Integer, Integer>() {
+      @StartBundle
+      public void startBundle(@SuppressWarnings("unused") Context c) {
+        throw new IllegalArgumentException("bogus");
+      }
+
+      @ProcessElement
+      public void processElement(@SuppressWarnings("unused") ProcessContext c) {
+      }
+    };
+
+    thrown.expect(UserCodeException.class);
+    thrown.expectMessage("bogus");
+    DoFnReflector.of(fn.getClass()).bindInvoker(fn).invokeStartBundle(null, null);
+  }
+
+  @Test
+  public void testFinishBundleException() throws Exception {
+    DoFnWithContext<Integer, Integer> fn = new DoFnWithContext<Integer, Integer>() {
+      @FinishBundle
+      public void finishBundle(@SuppressWarnings("unused") Context c) {
+        throw new IllegalArgumentException("bogus");
+      }
+
+      @ProcessElement
+      public void processElement(@SuppressWarnings("unused") ProcessContext c) {
+      }
+    };
+
+    thrown.expect(UserCodeException.class);
+    thrown.expectMessage("bogus");
+    DoFnReflector.of(fn.getClass()).bindInvoker(fn).invokeFinishBundle(null, null);
+  }
 }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2b47919c/sdks/java/microbenchmarks/src/main/java/org/apache/beam/sdk/microbenchmarks/transforms/DoFnReflectorBenchmark.java
----------------------------------------------------------------------
diff --git a/sdks/java/microbenchmarks/src/main/java/org/apache/beam/sdk/microbenchmarks/transforms/DoFnReflectorBenchmark.java b/sdks/java/microbenchmarks/src/main/java/org/apache/beam/sdk/microbenchmarks/transforms/DoFnReflectorBenchmark.java
index 1b8ec2a..f1dfbb9 100644
--- a/sdks/java/microbenchmarks/src/main/java/org/apache/beam/sdk/microbenchmarks/transforms/DoFnReflectorBenchmark.java
+++ b/sdks/java/microbenchmarks/src/main/java/org/apache/beam/sdk/microbenchmarks/transforms/DoFnReflectorBenchmark.java
@@ -22,6 +22,7 @@ import org.apache.beam.sdk.transforms.Aggregator;
 import org.apache.beam.sdk.transforms.Combine.CombineFn;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.DoFnReflector;
+import org.apache.beam.sdk.transforms.DoFnReflector.DoFnInvoker;
 import org.apache.beam.sdk.transforms.DoFnWithContext;
 import org.apache.beam.sdk.transforms.DoFnWithContext.ExtraContextFactory;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
@@ -72,10 +73,13 @@ public class DoFnReflectorBenchmark {
   private DoFnReflector doFnReflector;
   private DoFn<String, String> adaptedDoFnWithContext;
 
+  private DoFnInvoker<String, String> invoker;
+
   @Setup
   public void setUp() {
     doFnReflector = DoFnReflector.of(doFnWithContext.getClass());
     adaptedDoFnWithContext = doFnReflector.toDoFn(doFnWithContext);
+    invoker = doFnReflector.bindInvoker(doFnWithContext);
   }
 
   @Benchmark
@@ -92,8 +96,7 @@ public class DoFnReflectorBenchmark {
 
   @Benchmark
   public String invokeDoFnWithContext() throws Exception {
-    doFnReflector.invokeProcessElement(
-        doFnWithContext, stubDoFnWithContextContext, extraContextFactory);
+    invoker.invokeProcessElement(stubDoFnWithContextContext, extraContextFactory);
     return stubDoFnWithContextContext.output;
   }
 


Mime
View raw message