beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From t..@apache.org
Subject [24/50] incubator-beam git commit: Refactor and reuse parameter analysis in DoFnSignatures
Date Mon, 07 Nov 2016 19:59:14 GMT
Refactor and reuse parameter analysis in DoFnSignatures


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/8bf6d92c
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/8bf6d92c
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/8bf6d92c

Branch: refs/heads/apex-runner
Commit: 8bf6d92cf35d11f4f3b02dae677a4fe778d34a61
Parents: 71fa7cd
Author: Kenneth Knowles <klk@google.com>
Authored: Mon Oct 31 21:30:40 2016 -0700
Committer: Kenneth Knowles <klk@google.com>
Committed: Thu Nov 3 21:32:53 2016 -0700

----------------------------------------------------------------------
 .../sdk/transforms/reflect/DoFnSignature.java   |  21 +-
 .../sdk/transforms/reflect/DoFnSignatures.java  | 585 ++++++++++++-------
 .../DoFnSignaturesProcessElementTest.java       |  18 +-
 .../DoFnSignaturesSplittableDoFnTest.java       |   1 -
 .../transforms/reflect/DoFnSignaturesTest.java  |  35 +-
 .../reflect/DoFnSignaturesTestUtils.java        |   5 +-
 6 files changed, 446 insertions(+), 219 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/8bf6d92c/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java
index 431de02..7087efa 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java
@@ -126,12 +126,25 @@ public abstract class DoFnSignature {
     abstract DoFnSignature build();
   }
 
-  /** A method delegated to a annotated method of an underlying {@link DoFn}. */
+  /** A method delegated to an annotated method of an underlying {@link DoFn}. */
   public interface DoFnMethod {
     /** The annotated method itself. */
     Method targetMethod();
   }
 
+  /**
+   * A method delegated to an annotated method of an underlying {@link DoFn} that accepts a dynamic
+   * list of parameters.
+   */
+  public interface MethodWithExtraParameters extends DoFnMethod {
+    /**
+     * Types of optional parameters of the annotated method, in the order they appear.
+     *
+     * <p>Validation that these are allowed is external to this class.
+     */
+    List<Parameter> extraParameters();
+  }
+
   /** A descriptor for an optional parameter of the {@link DoFn.ProcessElement} method. */
   public abstract static class Parameter {
 
@@ -331,12 +344,13 @@ public abstract class DoFnSignature {
 
   /** Describes a {@link DoFn.ProcessElement} method. */
   @AutoValue
-  public abstract static class ProcessElementMethod implements DoFnMethod {
+  public abstract static class ProcessElementMethod implements MethodWithExtraParameters
{
     /** The annotated method itself. */
     @Override
     public abstract Method targetMethod();
 
     /** Types of optional parameters of the annotated method, in the order they appear. */
+    @Override
     public abstract List<Parameter> extraParameters();
 
     /** Concrete type of the {@link RestrictionTracker} parameter, if present. */
@@ -380,7 +394,7 @@ public abstract class DoFnSignature {
 
   /** Describes a {@link DoFn.OnTimer} method. */
   @AutoValue
-  public abstract static class OnTimerMethod implements DoFnMethod {
+  public abstract static class OnTimerMethod implements MethodWithExtraParameters {
 
     /** The id on the method's {@link DoFn.TimerId} annotation. */
     public abstract String id();
@@ -390,6 +404,7 @@ public abstract class DoFnSignature {
     public abstract Method targetMethod();
 
     /** Types of optional parameters of the annotated method, in the order they appear. */
+    @Override
     public abstract List<Parameter> extraParameters();
 
     static OnTimerMethod create(Method targetMethod, String id, List<Parameter> extraParameters)
{

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/8bf6d92c/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java
index c690ace..0475404 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java
@@ -19,6 +19,7 @@ package org.apache.beam.sdk.transforms.reflect;
 
 import static com.google.common.base.Preconditions.checkState;
 
+import com.google.auto.value.AutoValue;
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Maps;
@@ -41,9 +42,11 @@ import java.util.Map;
 import javax.annotation.Nullable;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.DoFn.StateId;
 import org.apache.beam.sdk.transforms.DoFn.TimerId;
-import org.apache.beam.sdk.transforms.reflect.DoFnSignature.OnTimerMethod;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter;
+import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.StateParameter;
+import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.TimerParameter;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature.StateDeclaration;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature.TimerDeclaration;
 import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
@@ -81,13 +84,134 @@ public class DoFnSignatures {
     return signature;
   }
 
+  /**
+   * The context for a {@link DoFn} class, for use in analysis.
+   *
+   * <p>It contains much of the information that eventually becomes part of the {@link
+   * DoFnSignature}, but in an intermediate state.
+   */
+  @VisibleForTesting
+  static class FnAnalysisContext {
+
+    private final Map<String, StateDeclaration> stateDeclarations = new HashMap<>();
+    private final Map<String, TimerDeclaration> timerDeclarations = new HashMap<>();
+
+    private FnAnalysisContext() {}
+
+    /** Create an empty context, with no declarations. */
+    public static FnAnalysisContext create() {
+      return new FnAnalysisContext();
+    }
+
+    /** State parameters declared in this context, keyed by {@link StateId}. Unmodifiable. */
+    public Map<String, StateDeclaration> getStateDeclarations() {
+      return Collections.unmodifiableMap(stateDeclarations);
+    }
+
+    /** Timer parameters declared in this context, keyed by {@link TimerId}. Unmodifiable. */
+    public Map<String, TimerDeclaration> getTimerDeclarations() {
+      return Collections.unmodifiableMap(timerDeclarations);
+    }
+
+    public void addStateDeclaration(StateDeclaration decl) {
+      stateDeclarations.put(decl.id(), decl);
+    }
+
+    public void addStateDeclarations(Iterable<StateDeclaration> decls) {
+      for (StateDeclaration decl : decls) {
+        addStateDeclaration(decl);
+      }
+    }
+
+    public void addTimerDeclaration(TimerDeclaration decl) {
+      timerDeclarations.put(decl.id(), decl);
+    }
+
+    public void addTimerDeclarations(Iterable<TimerDeclaration> decls) {
+      for (TimerDeclaration decl : decls) {
+        addTimerDeclaration(decl);
+      }
+    }
+  }
+
+  /**
+   * The context of analysis within a particular method.
+   *
+   * <p>It contains much of the information that eventually becomes part of the {@link
+   * DoFnSignature.MethodWithExtraParameters}, but in an intermediate state.
+   */
+  private static class MethodAnalysisContext {
+
+    private final Map<String, StateParameter> stateParameters = new HashMap<>();
+    private final Map<String, TimerParameter> timerParameters = new HashMap<>();
+    private final List<Parameter> extraParameters = new ArrayList<>();
+
+    private MethodAnalysisContext() {}
+
+    /** State parameters declared in this context, keyed by {@link StateId}. */
+    public Map<String, StateParameter> getStateParameters() {
+      return Collections.unmodifiableMap(stateParameters);
+    }
+
+    /** Timer parameters declared in this context, keyed by {@link TimerId}. */
+    public Map<String, TimerParameter> getTimerParameters() {
+      return Collections.unmodifiableMap(timerParameters);
+    }
+
+    /** Extra parameters in their entirety. Unmodifiable. */
+    public List<Parameter> getExtraParameters() {
+      return Collections.unmodifiableList(extraParameters);
+    }
+
+    /**
+     * Adds the provided {@link Parameter} to this context, also indexing it by id when it is a
+     * {@link StateParameter} or {@link TimerParameter}.
+     */
+    public void addParameter(Parameter param) {
+      extraParameters.add(param);
+
+      if (param instanceof StateParameter) {
+        StateParameter stateParameter = (StateParameter) param;
+        stateParameters.put(stateParameter.referent().id(), stateParameter);
+      }
+      if (param instanceof TimerParameter) {
+        TimerParameter timerParameter = (TimerParameter) param;
+        timerParameters.put(timerParameter.referent().id(), timerParameter);
+      }
+    }
+
+    /** Create an empty context, with no declarations. */
+    public static MethodAnalysisContext create() {
+      return new MethodAnalysisContext();
+    }
+  }
+
+  @AutoValue
+  abstract static class ParameterDescription {
+    public abstract Method getMethod();
+    public abstract int getIndex();
+    public abstract TypeDescriptor<?> getType();
+    public abstract List<Annotation> getAnnotations();
+
+    public static ParameterDescription of(
+        Method method, int index, TypeDescriptor<?> type, List<Annotation> annotations)
{
+      return new AutoValue_DoFnSignatures_ParameterDescription(method, index, type, annotations);
+    }
+
+    public static ParameterDescription of(
+        Method method, int index, TypeDescriptor<?> type, Annotation[] annotations)
{
+      return new AutoValue_DoFnSignatures_ParameterDescription(
+          method, index, type, Arrays.asList(annotations));
+    }
+  }
+
   /** Analyzes a given {@link DoFn} class and extracts its {@link DoFnSignature}. */
   private static DoFnSignature parseSignature(Class<? extends DoFn<?, ?>> fnClass)
{
-    DoFnSignature.Builder builder = DoFnSignature.builder();
+    DoFnSignature.Builder signatureBuilder = DoFnSignature.builder();
 
     ErrorReporter errors = new ErrorReporter(null, fnClass.getName());
     errors.checkArgument(DoFn.class.isAssignableFrom(fnClass), "Must be subtype of DoFn");
-    builder.setFnClass(fnClass);
+    signatureBuilder.setFnClass(fnClass);
 
     TypeDescriptor<? extends DoFn<?, ?>> fnT = TypeDescriptor.of(fnClass);
 
@@ -106,11 +230,9 @@ public class DoFnSignatures {
 
     // Find the state and timer declarations in advance of validating
     // method parameter lists
-    Map<String, StateDeclaration> stateDeclarations = analyzeStateDeclarations(errors,
fnClass);
-    builder.setStateDeclarations(stateDeclarations);
-
-    Map<String, TimerDeclaration> timerDeclarations = analyzeTimerDeclarations(errors,
fnClass);
-    builder.setTimerDeclarations(timerDeclarations);
+    FnAnalysisContext fnContext = FnAnalysisContext.create();
+    fnContext.addStateDeclarations(analyzeStateDeclarations(errors, fnClass).values());
+    fnContext.addTimerDeclarations(analyzeTimerDeclarations(errors, fnClass).values());
 
     Method processElementMethod =
         findAnnotatedMethod(errors, DoFn.ProcessElement.class, fnClass, true);
@@ -135,12 +257,12 @@ public class DoFnSignatures {
     for (Method onTimerMethod : onTimerMethods) {
       String id = onTimerMethod.getAnnotation(DoFn.OnTimer.class).value();
       errors.checkArgument(
-          timerDeclarations.containsKey(id),
+          fnContext.getTimerDeclarations().containsKey(id),
           "Callback %s is for for undeclared timer %s",
           onTimerMethod,
           id);
 
-      TimerDeclaration timerDecl = timerDeclarations.get(id);
+      TimerDeclaration timerDecl = fnContext.getTimerDeclarations().get(id);
       errors.checkArgument(
           timerDecl.field().getDeclaringClass().equals(onTimerMethod.getDeclaringClass()),
           "Callback %s is for timer %s declared in a different class %s."
@@ -149,13 +271,14 @@ public class DoFnSignatures {
           id,
           timerDecl.field().getDeclaringClass().getCanonicalName());
 
-      onTimerMethodMap.put(id, OnTimerMethod.create(onTimerMethod, id, Collections.EMPTY_LIST));
+      onTimerMethodMap.put(
+          id, analyzeOnTimerMethod(errors, fnT, onTimerMethod, id, outputT, fnContext));
     }
-    builder.setOnTimerMethods(onTimerMethodMap);
+    signatureBuilder.setOnTimerMethods(onTimerMethodMap);
 
     // Check the converse - that all timers have a callback. This could be relaxed to only
     // those timers used in methods, once method parameter lists support timers.
-    for (TimerDeclaration decl : timerDeclarations.values()) {
+    for (TimerDeclaration decl : fnContext.getTimerDeclarations().values()) {
       errors.checkArgument(
           onTimerMethodMap.containsKey(decl.id()),
           "No callback registered via %s for timer %s",
@@ -172,30 +295,29 @@ public class DoFnSignatures {
             processElementMethod,
             inputT,
             outputT,
-            stateDeclarations,
-            timerDeclarations);
-    builder.setProcessElement(processElement);
+            fnContext);
+    signatureBuilder.setProcessElement(processElement);
 
     if (startBundleMethod != null) {
       ErrorReporter startBundleErrors = errors.forMethod(DoFn.StartBundle.class, startBundleMethod);
-      builder.setStartBundle(
+      signatureBuilder.setStartBundle(
           analyzeBundleMethod(startBundleErrors, fnT, startBundleMethod, inputT, outputT));
     }
 
     if (finishBundleMethod != null) {
       ErrorReporter finishBundleErrors =
           errors.forMethod(DoFn.FinishBundle.class, finishBundleMethod);
-      builder.setFinishBundle(
+      signatureBuilder.setFinishBundle(
           analyzeBundleMethod(finishBundleErrors, fnT, finishBundleMethod, inputT, outputT));
     }
 
     if (setupMethod != null) {
-      builder.setSetup(
+      signatureBuilder.setSetup(
           analyzeLifecycleMethod(errors.forMethod(DoFn.Setup.class, setupMethod), setupMethod));
     }
 
     if (teardownMethod != null) {
-      builder.setTeardown(
+      signatureBuilder.setTeardown(
           analyzeLifecycleMethod(
               errors.forMethod(DoFn.Teardown.class, teardownMethod), teardownMethod));
     }
@@ -205,7 +327,7 @@ public class DoFnSignatures {
     if (getInitialRestrictionMethod != null) {
       getInitialRestrictionErrors =
           errors.forMethod(DoFn.GetInitialRestriction.class, getInitialRestrictionMethod);
-      builder.setGetInitialRestriction(
+      signatureBuilder.setGetInitialRestriction(
           getInitialRestriction =
               analyzeGetInitialRestrictionMethod(
                   getInitialRestrictionErrors, fnT, getInitialRestrictionMethod, inputT));
@@ -215,7 +337,7 @@ public class DoFnSignatures {
     if (splitRestrictionMethod != null) {
       ErrorReporter splitRestrictionErrors =
           errors.forMethod(DoFn.SplitRestriction.class, splitRestrictionMethod);
-      builder.setSplitRestriction(
+      signatureBuilder.setSplitRestriction(
           splitRestriction =
               analyzeSplitRestrictionMethod(
                   splitRestrictionErrors, fnT, splitRestrictionMethod, inputT));
@@ -225,7 +347,7 @@ public class DoFnSignatures {
     if (getRestrictionCoderMethod != null) {
       ErrorReporter getRestrictionCoderErrors =
           errors.forMethod(DoFn.GetRestrictionCoder.class, getRestrictionCoderMethod);
-      builder.setGetRestrictionCoder(
+      signatureBuilder.setGetRestrictionCoder(
           getRestrictionCoder =
               analyzeGetRestrictionCoderMethod(
                   getRestrictionCoderErrors, fnT, getRestrictionCoderMethod));
@@ -234,13 +356,16 @@ public class DoFnSignatures {
     DoFnSignature.NewTrackerMethod newTracker = null;
     if (newTrackerMethod != null) {
       ErrorReporter newTrackerErrors = errors.forMethod(DoFn.NewTracker.class, newTrackerMethod);
-      builder.setNewTracker(
+      signatureBuilder.setNewTracker(
           newTracker = analyzeNewTrackerMethod(newTrackerErrors, fnT, newTrackerMethod));
     }
 
-    builder.setIsBoundedPerElement(inferBoundedness(fnT, processElement, errors));
+    signatureBuilder.setIsBoundedPerElement(inferBoundedness(fnT, processElement, errors));
 
-    DoFnSignature signature = builder.build();
+    signatureBuilder.setStateDeclarations(fnContext.getStateDeclarations());
+    signatureBuilder.setTimerDeclarations(fnContext.getTimerDeclarations());
+
+    DoFnSignature signature = signatureBuilder.build();
 
     // Additional validation for splittable DoFn's.
     if (processElement.isSplittable()) {
@@ -452,14 +577,49 @@ public class DoFnSignatures {
   }
 
   @VisibleForTesting
+  static DoFnSignature.OnTimerMethod analyzeOnTimerMethod(
+      ErrorReporter errors,
+      TypeDescriptor<? extends DoFn<?, ?>> fnClass,
+      Method m,
+      String timerId,
+      TypeDescriptor<?> outputT,
+      FnAnalysisContext fnContext) {
+    errors.checkArgument(void.class.equals(m.getReturnType()), "Must return void");
+
+    Type[] params = m.getGenericParameterTypes();
+
+    MethodAnalysisContext methodContext = MethodAnalysisContext.create();
+
+    List<DoFnSignature.Parameter> extraParameters = new ArrayList<>();
+    TypeDescriptor<?> expectedOutputReceiverT = outputReceiverTypeOf(outputT);
+    ErrorReporter onTimerErrors = errors.forMethod(DoFn.OnTimer.class, m);
+    for (int i = 0; i < params.length; ++i) {
+      extraParameters.add(
+          analyzeExtraParameter(
+              onTimerErrors,
+              fnContext,
+              methodContext,
+              fnClass,
+              ParameterDescription.of(
+                  m,
+                  i,
+                  fnClass.resolveType(params[i]),
+                  Arrays.asList(m.getParameterAnnotations()[i])),
+              null /* restriction type not applicable */,
+              expectedOutputReceiverT));
+    }
+
+    return DoFnSignature.OnTimerMethod.create(m, timerId, extraParameters);
+  }
+
+  @VisibleForTesting
   static DoFnSignature.ProcessElementMethod analyzeProcessElementMethod(
       ErrorReporter errors,
       TypeDescriptor<? extends DoFn<?, ?>> fnClass,
       Method m,
       TypeDescriptor<?> inputT,
       TypeDescriptor<?> outputT,
-      Map<String, StateDeclaration> stateDeclarations,
-      Map<String, TimerDeclaration> timerDeclarations) {
+      FnAnalysisContext fnContext) {
     errors.checkArgument(
         void.class.equals(m.getReturnType())
             || DoFn.ProcessContinuation.class.equals(m.getReturnType()),
@@ -468,6 +628,8 @@ public class DoFnSignatures {
 
     TypeDescriptor<?> processContextT = doFnProcessContextTypeOf(inputT, outputT);
 
+    MethodAnalysisContext methodContext = MethodAnalysisContext.create();
+
     Type[] params = m.getGenericParameterTypes();
     TypeDescriptor<?> contextT = null;
     if (params.length > 0) {
@@ -478,192 +640,211 @@ public class DoFnSignatures {
         "Must take %s as the first argument",
         formatType(processContextT));
 
-    List<DoFnSignature.Parameter> extraParameters = new ArrayList<>();
-    Map<String, DoFnSignature.Parameter> stateParameters = new HashMap<>();
-    Map<String, DoFnSignature.Parameter> timerParameters = new HashMap<>();
-    TypeDescriptor<?> trackerT = null;
-
+    TypeDescriptor<?> trackerT = getTrackerType(fnClass, m);
     TypeDescriptor<?> expectedInputProviderT = inputProviderTypeOf(inputT);
     TypeDescriptor<?> expectedOutputReceiverT = outputReceiverTypeOf(outputT);
     for (int i = 1; i < params.length; ++i) {
-      TypeDescriptor<?> paramT = fnClass.resolveType(params[i]);
-      Class<?> rawType = paramT.getRawType();
-      if (rawType.equals(BoundedWindow.class)) {
-        errors.checkArgument(
-            !extraParameters.contains(DoFnSignature.Parameter.boundedWindow()),
-            "Multiple %s parameters",
-            BoundedWindow.class.getSimpleName());
-        extraParameters.add(DoFnSignature.Parameter.boundedWindow());
-      } else if (rawType.equals(DoFn.InputProvider.class)) {
-        errors.checkArgument(
-            !extraParameters.contains(DoFnSignature.Parameter.inputProvider()),
-            "Multiple %s parameters",
-            DoFn.InputProvider.class.getSimpleName());
-        errors.checkArgument(
-            paramT.equals(expectedInputProviderT),
-            "Wrong type of %s parameter: %s, should be %s",
-            DoFn.InputProvider.class.getSimpleName(),
-            formatType(paramT),
-            formatType(expectedInputProviderT));
-        extraParameters.add(DoFnSignature.Parameter.inputProvider());
-      } else if (rawType.equals(DoFn.OutputReceiver.class)) {
-        errors.checkArgument(
-            !extraParameters.contains(DoFnSignature.Parameter.outputReceiver()),
-            "Multiple %s parameters",
-            DoFn.OutputReceiver.class.getSimpleName());
-        errors.checkArgument(
-            paramT.equals(expectedOutputReceiverT),
-            "Wrong type of %s parameter: %s, should be %s",
-            DoFn.OutputReceiver.class.getSimpleName(),
-            formatType(paramT),
-            formatType(expectedOutputReceiverT));
-        extraParameters.add(DoFnSignature.Parameter.outputReceiver());
-      } else if (Timer.class.equals(rawType)) {
-        // m.getParameters() is not available until Java 8
-        Annotation[] annotations = m.getParameterAnnotations()[i];
-        String id = null;
-        for (Annotation anno : annotations) {
-          if (anno.annotationType().equals(DoFn.TimerId.class)) {
-            id = ((DoFn.TimerId) anno).value();
-            break;
-          }
-        }
-        errors.checkArgument(
-            id != null,
-            "%s parameter of type %s at index %s missing %s annotation",
-            fnClass.getRawType().getName(),
-            params[i],
-            i,
-            DoFn.TimerId.class.getSimpleName());
 
-        errors.checkArgument(
-            !timerParameters.containsKey(id),
-            "%s parameter of type %s at index %s duplicates %s(\"%s\") on other parameter",
-            fnClass.getRawType().getName(),
-            params[i],
-            i,
-            DoFn.TimerId.class.getSimpleName(),
-            id);
-
-        TimerDeclaration timerDecl = timerDeclarations.get(id);
-        errors.checkArgument(
-            timerDecl != null,
-            "%s parameter of type %s at index %s references undeclared %s \"%s\"",
-            fnClass.getRawType().getName(),
-            params[i],
-            i,
-            TimerId.class.getSimpleName(),
-            id);
+      Parameter extraParam =
+          analyzeExtraParameter(
+              errors.forMethod(DoFn.ProcessElement.class, m),
+              fnContext,
+              methodContext,
+              fnClass,
+              ParameterDescription.of(
+                  m,
+                  i,
+                  fnClass.resolveType(params[i]),
+                  Arrays.asList(m.getParameterAnnotations()[i])),
+              expectedInputProviderT,
+              expectedOutputReceiverT);
+
+      methodContext.addParameter(extraParam);
+    }
 
-        errors.checkArgument(
-            timerDecl.field().getDeclaringClass().equals(m.getDeclaringClass()),
-            "Method %s has %s parameter at index %s for timer %s"
-                + " declared in a different class %s."
-                + " Timers may be referenced only in the lexical scope where they are declared.",
-            m,
-            Timer.class.getSimpleName(),
-            i,
-            id,
-            timerDecl.field().getDeclaringClass().getName());
+    // A splittable DoFn can not have any other extra context parameters.
+    if (methodContext.getExtraParameters().contains(DoFnSignature.Parameter.restrictionTracker()))
{
+      errors.checkArgument(
+          methodContext.getExtraParameters().size() == 1,
+          "Splittable DoFn must not have any extra arguments, but has: %s",
+          trackerT,
+          methodContext.getExtraParameters());
+    }
 
-        DoFnSignature.Parameter.TimerParameter timerParameter = Parameter.timerParameter(timerDecl);
-        timerParameters.put(id, timerParameter);
-        extraParameters.add(timerParameter);
+    return DoFnSignature.ProcessElementMethod.create(
+        m,
+        methodContext.getExtraParameters(),
+        trackerT,
+        DoFn.ProcessContinuation.class.equals(m.getReturnType()));
+  }
 
-      } else if (RestrictionTracker.class.isAssignableFrom(rawType)) {
-        errors.checkArgument(
-            !extraParameters.contains(DoFnSignature.Parameter.restrictionTracker()),
-            "Multiple %s parameters",
-            RestrictionTracker.class.getSimpleName());
-        extraParameters.add(DoFnSignature.Parameter.restrictionTracker());
-        trackerT = paramT;
-      } else if (State.class.isAssignableFrom(rawType)) {
-        // m.getParameters() is not available until Java 8
-        Annotation[] annotations = m.getParameterAnnotations()[i];
-        String id = null;
-        for (Annotation anno : annotations) {
-          if (anno.annotationType().equals(DoFn.StateId.class)) {
-            id = ((DoFn.StateId) anno).value();
-            break;
-          }
-        }
-        errors.checkArgument(
-            id != null,
-            "%s parameter of type %s at index %s missing %s annotation",
-            fnClass.getRawType().getName(),
-            params[i],
-            i,
-            DoFn.StateId.class.getSimpleName());
+  private static Parameter analyzeExtraParameter(
+      ErrorReporter methodErrors,
+      FnAnalysisContext fnContext,
+      MethodAnalysisContext methodContext,
+      TypeDescriptor<? extends DoFn<?, ?>> fnClass,
+      ParameterDescription param,
+      TypeDescriptor<?> expectedInputProviderT,
+      TypeDescriptor<?> expectedOutputReceiverT) {
+    TypeDescriptor<?> paramT = param.getType();
+    Class<?> rawType = paramT.getRawType();
+
+    ErrorReporter paramErrors = methodErrors.forParameter(param);
+
+    if (rawType.equals(BoundedWindow.class)) {
+      methodErrors.checkArgument(
+          !methodContext.getExtraParameters().contains(Parameter.boundedWindow()),
+          "Multiple %s parameters",
+          BoundedWindow.class.getSimpleName());
+      return Parameter.boundedWindow();
+    } else if (rawType.equals(DoFn.InputProvider.class)) {
+      methodErrors.checkArgument(
+          !methodContext.getExtraParameters().contains(Parameter.inputProvider()),
+          "Multiple %s parameters",
+          DoFn.InputProvider.class.getSimpleName());
+      paramErrors.checkArgument(
+          paramT.equals(expectedInputProviderT),
+          "%s is for %s when it should be %s",
+          DoFn.InputProvider.class.getSimpleName(),
+          formatType(paramT),
+          formatType(expectedInputProviderT));
+      return Parameter.inputProvider();
+    } else if (rawType.equals(DoFn.OutputReceiver.class)) {
+      methodErrors.checkArgument(
+          !methodContext.getExtraParameters().contains(Parameter.outputReceiver()),
+          "Multiple %s parameters",
+          DoFn.OutputReceiver.class.getSimpleName());
+      paramErrors.checkArgument(
+          paramT.equals(expectedOutputReceiverT),
+          "%s is for %s when it should be %s",
+          DoFn.OutputReceiver.class.getSimpleName(),
+          formatType(paramT),
+          formatType(expectedOutputReceiverT));
+      return Parameter.outputReceiver();
+
+    } else if (RestrictionTracker.class.isAssignableFrom(rawType)) {
+      methodErrors.checkArgument(
+          !methodContext.getExtraParameters().contains(Parameter.restrictionTracker()),
+          "Multiple %s parameters",
+          RestrictionTracker.class.getSimpleName());
+      return Parameter.restrictionTracker();
+
+    } else if (rawType.equals(Timer.class)) {
+      // m.getParameters() is not available until Java 8
+      String id = getTimerId(param.getAnnotations());
+
+      paramErrors.checkArgument(
+          id != null,
+          "%s missing %s annotation",
+          Timer.class.getSimpleName(),
+          TimerId.class.getSimpleName());
+
+      paramErrors.checkArgument(
+          !methodContext.getTimerParameters().containsKey(id),
+          "duplicate %s: \"%s\"",
+          TimerId.class.getSimpleName(),
+          id);
 
-        errors.checkArgument(
-            !stateParameters.containsKey(id),
-            "%s parameter of type %s at index %s duplicates %s(\"%s\") on other parameter",
-            fnClass.getRawType().getName(),
-            params[i],
-            i,
-            DoFn.StateId.class.getSimpleName(),
-            id);
+      TimerDeclaration timerDecl = fnContext.getTimerDeclarations().get(id);
+      paramErrors.checkArgument(
+          timerDecl != null,
+          "reference to undeclared %s: \"%s\"",
+          TimerId.class.getSimpleName(),
+          id);
 
-        // By static typing this is already a well-formed State subclass
-        TypeDescriptor<? extends State> stateType =
-            (TypeDescriptor<? extends State>)
-                TypeDescriptor.of(fnClass.getType())
-                    .resolveType(params[i]);
+      paramErrors.checkArgument(
+          timerDecl.field().getDeclaringClass().equals(param.getMethod().getDeclaringClass()),
+          "%s %s declared in a different class %s."
+              + " Timers may be referenced only in the lexical scope where they are declared.",
+          TimerId.class.getSimpleName(),
+          id,
+          timerDecl.field().getDeclaringClass().getName());
+
+      return Parameter.timerParameter(timerDecl);
+
+    } else if (State.class.isAssignableFrom(rawType)) {
+      // m.getParameters() is not available until Java 8
+      String id = getStateId(param.getAnnotations());
+      paramErrors.checkArgument(
+          id != null,
+          "missing %s annotation",
+          DoFn.StateId.class.getSimpleName());
+
+      paramErrors.checkArgument(
+          !methodContext.getStateParameters().containsKey(id),
+          "duplicate %s: \"%s\"",
+          DoFn.StateId.class.getSimpleName(),
+          id);
 
-        StateDeclaration stateDecl = stateDeclarations.get(id);
-        errors.checkArgument(
-            stateDecl != null,
-            "%s parameter of type %s at index %s references undeclared %s \"%s\"",
-            fnClass.getRawType().getName(),
-            params[i],
-            i,
-            DoFn.StateId.class.getSimpleName(),
-            id);
+      // By static typing this is already a well-formed State subclass
+      TypeDescriptor<? extends State> stateType = (TypeDescriptor<? extends State>)
param.getType();
 
-        errors.checkArgument(
-            stateDecl.stateType().equals(stateType),
-            "%s parameter at index %s has type %s but is a reference to StateId %s of type
%s",
-            fnClass.getRawType().getName(),
-            i,
-            params[i],
-            id,
-            stateDecl.stateType());
+      StateDeclaration stateDecl = fnContext.getStateDeclarations().get(id);
+      paramErrors.checkArgument(
+          stateDecl != null,
+          "reference to undeclared %s: \"%s\"",
+          DoFn.StateId.class.getSimpleName(),
+          id);
 
-        errors.checkArgument(
-            stateDecl.field().getDeclaringClass().equals(m.getDeclaringClass()),
-            "Method %s has State parameter at index %s for state %s"
-                + " declared in a different class %s."
-                + " State may be referenced only in the class where it is declared.",
-            m,
-            i,
-            id,
-            stateDecl.field().getDeclaringClass().getName());
-
-        DoFnSignature.Parameter.StateParameter stateParameter = Parameter.stateParameter(stateDecl);
-        stateParameters.put(id, stateParameter);
-        extraParameters.add(stateParameter);
-      } else {
-        List<String> allowedParamTypes =
-            Arrays.asList(
-                formatType(new TypeDescriptor<BoundedWindow>() {}),
-                formatType(new TypeDescriptor<RestrictionTracker<?>>() {}));
-        errors.throwIllegalArgument(
-            "%s is not a valid context parameter. Should be one of %s",
-            formatType(paramT), allowedParamTypes);
+      paramErrors.checkArgument(
+          stateDecl.stateType().equals(stateType),
+          "reference to %s %s with different type %s",
+          StateId.class.getSimpleName(),
+          id,
+          stateDecl.stateType());
+
+      paramErrors.checkArgument(
+          stateDecl.field().getDeclaringClass().equals(param.getMethod().getDeclaringClass()),
+          "%s %s declared in a different class %s."
+              + " State may be referenced only in the class where it is declared.",
+          StateId.class.getSimpleName(),
+          id,
+          stateDecl.field().getDeclaringClass().getName());
+
+      return Parameter.stateParameter(stateDecl);
+    } else {
+      List<String> allowedParamTypes =
+          Arrays.asList(
+              formatType(new TypeDescriptor<BoundedWindow>() {}),
+              formatType(new TypeDescriptor<RestrictionTracker<?>>() {}));
+      paramErrors.throwIllegalArgument(
+          "%s is not a valid context parameter. Should be one of %s",
+          formatType(paramT), allowedParamTypes);
+      // Unreachable
+      return null;
+    }
+  }
+
+  @Nullable
+  private static String getTimerId(List<Annotation> annotations) {
+    for (Annotation anno : annotations) {
+      if (anno.annotationType().equals(DoFn.TimerId.class)) {
+        return ((DoFn.TimerId) anno).value();
       }
     }
+    return null;
+  }
 
-    // A splittable DoFn can not have any other extra context parameters.
-    if (extraParameters.contains(DoFnSignature.Parameter.restrictionTracker())) {
-      errors.checkArgument(
-          extraParameters.size() == 1,
-          "Splittable DoFn must not have any extra arguments apart from BoundedWindow, but has: %s",
-          trackerT,
-          extraParameters);
+  @Nullable
+  private static String getStateId(List<Annotation> annotations) {
+    for (Annotation anno : annotations) {
+      if (anno.annotationType().equals(DoFn.StateId.class)) {
+        return ((DoFn.StateId) anno).value();
+      }
     }
+    return null;
+  }
 
-    return DoFnSignature.ProcessElementMethod.create(
-        m, extraParameters, trackerT, DoFn.ProcessContinuation.class.equals(m.getReturnType()));
+  @Nullable
+  private static TypeDescriptor<?> getTrackerType(TypeDescriptor<?> fnClass, Method method) {
+    Type[] params = method.getGenericParameterTypes();
+    for (int i = 0; i < params.length; i++) {
+      TypeDescriptor<?> paramT = fnClass.resolveType(params[i]);
+      if (RestrictionTracker.class.isAssignableFrom(paramT.getRawType())) {
+        return paramT;
+      }
+    }
+    return null;
   }
 
   @VisibleForTesting
@@ -905,7 +1086,7 @@ public class DoFnSignatures {
     return matches;
   }
 
-  private static ImmutableMap<String, DoFnSignature.StateDeclaration> analyzeStateDeclarations(
+  private static Map<String, DoFnSignature.StateDeclaration> analyzeStateDeclarations(
       ErrorReporter errors,
       Class<?> fnClazz) {
 
@@ -1015,6 +1196,14 @@ public class DoFnSignatures {
               annotation.getSimpleName(), (method == null) ? "(absent)" : format(method)));
     }
 
+    ErrorReporter forParameter(ParameterDescription param) {
+      return new ErrorReporter(
+          this,
+          String.format(
+              "parameter of type %s at index %s",
+              param.getType(), param.getIndex()));
+    }
+
     void throwIllegalArgument(String message, Object... args) {
       throw new IllegalArgumentException(label + ": " + String.format(message, args));
     }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/8bf6d92c/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesProcessElementTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesProcessElementTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesProcessElementTest.java
index 329a099..6cbc95e 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesProcessElementTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesProcessElementTest.java
@@ -96,9 +96,9 @@ public class DoFnSignaturesProcessElementTest {
   @Test
   public void testBadGenericsTwoArgs() throws Exception {
     thrown.expect(IllegalArgumentException.class);
-    thrown.expectMessage(
-        "Wrong type of OutputReceiver parameter: "
-            + "OutputReceiver<Integer>, should be OutputReceiver<String>");
+    thrown.expectMessage("OutputReceiver<Integer>");
+    thrown.expectMessage("should be");
+    thrown.expectMessage("OutputReceiver<String>");
 
     analyzeProcessElementMethod(
         new AnonymousMethod() {
@@ -112,9 +112,9 @@ public class DoFnSignaturesProcessElementTest {
   @Test
   public void testBadGenericWildCards() throws Exception {
     thrown.expect(IllegalArgumentException.class);
-    thrown.expectMessage(
-        "Wrong type of OutputReceiver parameter: "
-            + "OutputReceiver<? super Integer>, should be OutputReceiver<String>");
+    thrown.expectMessage("OutputReceiver<? super Integer>");
+    thrown.expectMessage("should be");
+    thrown.expectMessage("OutputReceiver<String>");
 
     analyzeProcessElementMethod(
         new AnonymousMethod() {
@@ -137,9 +137,9 @@ public class DoFnSignaturesProcessElementTest {
   @Test
   public void testBadTypeVariables() throws Exception {
     thrown.expect(IllegalArgumentException.class);
-    thrown.expectMessage(
-        "Wrong type of OutputReceiver parameter: "
-            + "OutputReceiver<InputT>, should be OutputReceiver<OutputT>");
+    thrown.expectMessage("OutputReceiver<InputT>");
+    thrown.expectMessage("should be");
+    thrown.expectMessage("OutputReceiver<OutputT>");
 
     DoFnSignatures.INSTANCE.getSignature(BadTypeVariables.class);
   }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/8bf6d92c/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java
index 573701b..0751b59 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java
@@ -88,7 +88,6 @@ public class DoFnSignaturesSplittableDoFnTest {
   public void testSplittableProcessElementMustNotHaveOtherParams() throws Exception {
     thrown.expect(IllegalArgumentException.class);
     thrown.expectMessage("must not have any extra arguments");
-    thrown.expectMessage("BoundedWindow");
 
     DoFnSignature.ProcessElementMethod signature =
         analyzeProcessElementMethod(

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/8bf6d92c/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java
index 52ecb2a..4187e0a 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java
@@ -21,6 +21,7 @@ import static org.apache.beam.sdk.transforms.reflect.DoFnSignaturesTestUtils.err
 import static org.hamcrest.Matchers.anyOf;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.not;
 import static org.junit.Assert.assertThat;
 import static org.junit.Assert.fail;
@@ -29,11 +30,11 @@ import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.coders.VarIntCoder;
 import org.apache.beam.sdk.coders.VarLongCoder;
 import org.apache.beam.sdk.transforms.DoFn;
-import org.apache.beam.sdk.transforms.DoFn.OnTimer;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.StateParameter;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.TimerParameter;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignaturesTestUtils.FakeDoFn;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.util.TimeDomain;
 import org.apache.beam.sdk.util.Timer;
 import org.apache.beam.sdk.util.TimerSpec;
@@ -249,7 +250,7 @@ public class DoFnSignaturesTest {
   @Test
   public void testTimerParameterDuplicate() throws Exception {
     thrown.expect(IllegalArgumentException.class);
-    thrown.expectMessage("duplicates");
+    thrown.expectMessage("duplicate");
     thrown.expectMessage("my-id");
     thrown.expectMessage("myProcessElement");
     thrown.expectMessage("index 2");
@@ -291,6 +292,28 @@ public class DoFnSignaturesTest {
   }
 
   @Test
+  public void testWindowParamOnTimer() throws Exception {
+    final String timerId = "some-timer-id";
+
+    DoFnSignature sig =
+        DoFnSignatures.INSTANCE.getSignature(new DoFn<String, String>() {
+          @TimerId(timerId)
+          private final TimerSpec myfield1 = TimerSpecs.timer(TimeDomain.EVENT_TIME);
+
+          @ProcessElement
+          public void process(ProcessContext c) {}
+
+          @OnTimer(timerId)
+          public void onTimer(BoundedWindow w) {}
+        }.getClass());
+
+    assertThat(sig.onTimerMethods().get(timerId).extraParameters().size(), equalTo(1));
+    assertThat(
+        sig.onTimerMethods().get(timerId).extraParameters().get(0),
+        instanceOf(DoFnSignature.Parameter.BoundedWindowParameter.class));
+  }
+
+  @Test
   public void testDeclAndUsageOfTimerInSuperclass() throws Exception {
     DoFnSignature sig =
         DoFnSignatures.INSTANCE.getSignature(new DoFnOverridingAbstractTimerUse().getClass());
@@ -525,7 +548,7 @@ public class DoFnSignaturesTest {
   @Test
   public void testStateParameterDuplicate() throws Exception {
     thrown.expect(IllegalArgumentException.class);
-    thrown.expectMessage("duplicates");
+    thrown.expectMessage("duplicate");
     thrown.expectMessage("my-id");
     thrown.expectMessage("myProcessElement");
     thrown.expectMessage("index 2");
@@ -549,7 +572,8 @@ public class DoFnSignaturesTest {
   public void testStateParameterWrongStateType() throws Exception {
     thrown.expect(IllegalArgumentException.class);
     thrown.expectMessage("WatermarkHoldState");
-    thrown.expectMessage("but is a reference to");
+    thrown.expectMessage("reference to");
+    thrown.expectMessage("different type");
     thrown.expectMessage("ValueState");
     thrown.expectMessage("my-id");
     thrown.expectMessage("myProcessElement");
@@ -572,7 +596,8 @@ public class DoFnSignaturesTest {
   public void testStateParameterWrongGenericType() throws Exception {
     thrown.expect(IllegalArgumentException.class);
     thrown.expectMessage("ValueState<java.lang.String>");
-    thrown.expectMessage("but is a reference to");
+    thrown.expectMessage("reference to");
+    thrown.expectMessage("different type");
     thrown.expectMessage("ValueState<java.lang.Integer>");
     thrown.expectMessage("my-id");
     thrown.expectMessage("myProcessElement");

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/8bf6d92c/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTestUtils.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTestUtils.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTestUtils.java
index 49e2ba7..b7d137a 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTestUtils.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTestUtils.java
@@ -18,9 +18,9 @@
 package org.apache.beam.sdk.transforms.reflect;
 
 import java.lang.reflect.Method;
-import java.util.Collections;
 import java.util.NoSuchElementException;
 import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.reflect.DoFnSignatures.FnAnalysisContext;
 import org.apache.beam.sdk.values.TypeDescriptor;
 
 /** Utilities for use in {@link DoFnSignatures} tests. */
@@ -61,7 +61,6 @@ class DoFnSignaturesTestUtils {
         method.getMethod(),
         TypeDescriptor.of(Integer.class),
         TypeDescriptor.of(String.class),
-        Collections.EMPTY_MAP,
-        Collections.EMPTY_MAP);
+        FnAnalysisContext.create());
   }
 }



Mime
View raw message