beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From aljos...@apache.org
Subject [1/2] beam git commit: [BEAM-1036] Support for new State API in Flink Batch Runner
Date Tue, 28 Feb 2017 10:11:26 GMT
Repository: beam
Updated Branches:
  refs/heads/master 13db84bb0 -> 2a2337460


[BEAM-1036] Support for new State API in Flink Batch Runner


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/ec5a8262
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/ec5a8262
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/ec5a8262

Branch: refs/heads/master
Commit: ec5a82620916b2297a8b349a605af8fadeb2ceb7
Parents: 13db84b
Author: JingsongLi <lzljs3620320@aliyun.com>
Authored: Tue Feb 28 02:07:33 2017 +0800
Committer: Aljoscha Krettek <aljoscha.krettek@gmail.com>
Committed: Tue Feb 28 11:02:48 2017 +0100

----------------------------------------------------------------------
 runners/flink/runner/pom.xml                    |   1 -
 .../flink/FlinkBatchTransformTranslators.java   | 130 +++++++++++-------
 .../functions/FlinkDoFnFunction.java            |  52 ++++++-
 .../functions/FlinkMultiOutputDoFnFunction.java | 131 ------------------
 .../FlinkMultiOutputPruningFunction.java        |   2 +-
 .../functions/FlinkStatefulDoFnFunction.java    | 134 +++++++++++++++++++
 6 files changed, 265 insertions(+), 185 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/ec5a8262/runners/flink/runner/pom.xml
----------------------------------------------------------------------
diff --git a/runners/flink/runner/pom.xml b/runners/flink/runner/pom.xml
index c00b328..8cc65b0 100644
--- a/runners/flink/runner/pom.xml
+++ b/runners/flink/runner/pom.xml
@@ -55,7 +55,6 @@
                   <groups>org.apache.beam.sdk.testing.RunnableOnService</groups>
                   <excludedGroups>
                     org.apache.beam.sdk.testing.FlattenWithHeterogeneousCoders,
-                    org.apache.beam.sdk.testing.UsesStatefulParDo,
                     org.apache.beam.sdk.testing.UsesTimersInParDo,
                     org.apache.beam.sdk.testing.UsesSplittableParDo,
                     org.apache.beam.sdk.testing.UsesAttemptedMetrics,

http://git-wip-us.apache.org/repos/asf/beam/blob/ec5a8262/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java
----------------------------------------------------------------------
diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java
index 99651c3..ed2f4aa 100644
--- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java
+++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java
@@ -32,10 +32,10 @@ import org.apache.beam.runners.flink.translation.functions.FlinkDoFnFunction;
 import org.apache.beam.runners.flink.translation.functions.FlinkMergingNonShuffleReduceFunction;
 import org.apache.beam.runners.flink.translation.functions.FlinkMergingPartialReduceFunction;
 import org.apache.beam.runners.flink.translation.functions.FlinkMergingReduceFunction;
-import org.apache.beam.runners.flink.translation.functions.FlinkMultiOutputDoFnFunction;
 import org.apache.beam.runners.flink.translation.functions.FlinkMultiOutputPruningFunction;
 import org.apache.beam.runners.flink.translation.functions.FlinkPartialReduceFunction;
 import org.apache.beam.runners.flink.translation.functions.FlinkReduceFunction;
+import org.apache.beam.runners.flink.translation.functions.FlinkStatefulDoFnFunction;
 import org.apache.beam.runners.flink.translation.types.CoderTypeInformation;
 import org.apache.beam.runners.flink.translation.types.KvKeySelector;
 import org.apache.beam.runners.flink.translation.wrappers.SourceInputFormat;
@@ -498,19 +498,9 @@ class FlinkBatchTransformTranslators {
     }
   }
 
-  private static void rejectStateAndTimers(DoFn<?, ?> doFn) {
+  private static void rejectTimers(DoFn<?, ?> doFn) {
     DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass());
 
-    if (signature.stateDeclarations().size() > 0) {
-      throw new UnsupportedOperationException(
-          String.format(
-              "Found %s annotations on %s, but %s cannot yet be used with state in the %s.",
-              DoFn.StateId.class.getSimpleName(),
-              doFn.getClass().getName(),
-              DoFn.class.getSimpleName(),
-              FlinkRunner.class.getSimpleName()));
-    }
-
     if (signature.timerDeclarations().size() > 0) {
       throw new UnsupportedOperationException(
           String.format(
@@ -527,13 +517,14 @@ class FlinkBatchTransformTranslators {
           ParDo.Bound<InputT, OutputT>> {
 
     @Override
+    @SuppressWarnings("unchecked")
     public void translateNode(
         ParDo.Bound<InputT, OutputT> transform,
 
         FlinkBatchTranslationContext context) {
       DoFn<InputT, OutputT> doFn = transform.getFn();
       rejectSplittable(doFn);
-      rejectStateAndTimers(doFn);
+      rejectTimers(doFn);
 
       DataSet<WindowedValue<InputT>> inputDataSet =
           context.getInputDataSet(context.getInput(transform));
@@ -550,23 +541,48 @@ class FlinkBatchTransformTranslators {
         sideInputStrategies.put(sideInput, sideInput.getWindowingStrategyInternal());
       }
 
-      FlinkDoFnFunction<InputT, OutputT> doFnWrapper =
-          new FlinkDoFnFunction<>(
-              doFn,
-              context.getOutput(transform).getWindowingStrategy(),
-              sideInputStrategies,
-              context.getPipelineOptions());
+      WindowingStrategy<?, ?> windowingStrategy =
+          context.getOutput(transform).getWindowingStrategy();
+
+      SingleInputUdfOperator<WindowedValue<InputT>, WindowedValue<OutputT>, ?> outputDataSet;
+      DoFnSignature signature = DoFnSignatures.getSignature(transform.getFn().getClass());
+      if (signature.stateDeclarations().size() > 0
+          || signature.timerDeclarations().size() > 0) {
+
+        // Based on the fact that the signature is stateful, DoFnSignatures ensures
+        // that it is also keyed
+        KvCoder<?, InputT> inputCoder =
+            (KvCoder<?, InputT>) context.getInput(transform).getCoder();
+
+        FlinkStatefulDoFnFunction<?, ?, OutputT> doFnWrapper = new FlinkStatefulDoFnFunction<>(
+            (DoFn) doFn, windowingStrategy, sideInputStrategies, context.getPipelineOptions(),
+            null, new TupleTag<OutputT>()
+        );
 
-      MapPartitionOperator<WindowedValue<InputT>, WindowedValue<OutputT>> outputDataSet =
-          new MapPartitionOperator<>(
-              inputDataSet,
-              typeInformation,
-              doFnWrapper,
-              transform.getName());
+        Grouping<WindowedValue<InputT>> grouping =
+            inputDataSet.groupBy(new KvKeySelector(inputCoder.getKeyCoder()));
+
+        outputDataSet = new GroupReduceOperator(
+            grouping, typeInformation, doFnWrapper, transform.getName());
+
+      } else {
+        FlinkDoFnFunction<InputT, OutputT> doFnWrapper =
+            new FlinkDoFnFunction<>(
+                doFn,
+                windowingStrategy,
+                sideInputStrategies,
+                context.getPipelineOptions(),
+                null, new TupleTag<OutputT>());
+
+        outputDataSet = new MapPartitionOperator<>(inputDataSet, typeInformation, doFnWrapper,
+            transform.getName());
+
+      }
 
       transformSideInputs(sideInputs, outputDataSet, context);
 
       context.setOutputDataSet(context.getOutput(transform), outputDataSet);
+
     }
   }
 
@@ -575,12 +591,13 @@ class FlinkBatchTransformTranslators {
           ParDo.BoundMulti<InputT, OutputT>> {
 
     @Override
+    @SuppressWarnings("unchecked")
     public void translateNode(
         ParDo.BoundMulti<InputT, OutputT> transform,
         FlinkBatchTranslationContext context) {
       DoFn<InputT, OutputT> doFn = transform.getFn();
       rejectSplittable(doFn);
-      rejectStateAndTimers(doFn);
+      rejectTimers(doFn);
       DataSet<WindowedValue<InputT>> inputDataSet =
           context.getInputDataSet(context.getInput(transform));
 
@@ -633,36 +650,57 @@ class FlinkBatchTransformTranslators {
         sideInputStrategies.put(sideInput, sideInput.getWindowingStrategyInternal());
       }
 
-      @SuppressWarnings("unchecked")
-      FlinkMultiOutputDoFnFunction<InputT, OutputT> doFnWrapper =
-          new FlinkMultiOutputDoFnFunction(
-              doFn,
-              windowingStrategy,
-              sideInputStrategies,
-              context.getPipelineOptions(),
-              outputMap,
-              transform.getMainOutputTag());
-
-      MapPartitionOperator<WindowedValue<InputT>, WindowedValue<RawUnionValue>> taggedDataSet =
-          new MapPartitionOperator<>(
-              inputDataSet,
-              typeInformation,
-              doFnWrapper,
-              transform.getName());
-
-      transformSideInputs(sideInputs, taggedDataSet, context);
+      SingleInputUdfOperator<WindowedValue<InputT>, WindowedValue<RawUnionValue>, ?> outputDataSet;
+      DoFnSignature signature = DoFnSignatures.getSignature(transform.getFn().getClass());
+      if (signature.stateDeclarations().size() > 0
+          || signature.timerDeclarations().size() > 0) {
+
+        // Based on the fact that the signature is stateful, DoFnSignatures ensures
+        // that it is also keyed
+        KvCoder<?, InputT> inputCoder =
+            (KvCoder<?, InputT>) context.getInput(transform).getCoder();
+
+        FlinkStatefulDoFnFunction<?, ?, OutputT> doFnWrapper = new FlinkStatefulDoFnFunction<>(
+            (DoFn) doFn, windowingStrategy, sideInputStrategies, context.getPipelineOptions(),
+            outputMap, transform.getMainOutputTag()
+        );
+
+        Grouping<WindowedValue<InputT>> grouping =
+            inputDataSet.groupBy(new KvKeySelector(inputCoder.getKeyCoder()));
+
+        outputDataSet =
+            new GroupReduceOperator(grouping, typeInformation, doFnWrapper, transform.getName());
+
+      } else {
+        FlinkDoFnFunction<InputT, RawUnionValue> doFnWrapper =
+            new FlinkDoFnFunction(
+                doFn,
+                windowingStrategy,
+                sideInputStrategies,
+                context.getPipelineOptions(),
+                outputMap,
+                transform.getMainOutputTag());
+
+        outputDataSet = new MapPartitionOperator<>(
+            inputDataSet, typeInformation,
+            doFnWrapper, transform.getName());
+
+      }
+
+      transformSideInputs(sideInputs, outputDataSet, context);
 
       for (TaggedPValue output : outputs) {
         pruneOutput(
-            taggedDataSet,
+            outputDataSet,
             context,
             outputMap.get(output.getTag()),
             (PCollection) output.getValue());
       }
+
     }
 
     private <T> void pruneOutput(
-        MapPartitionOperator<WindowedValue<InputT>, WindowedValue<RawUnionValue>> taggedDataSet,
+        DataSet<WindowedValue<RawUnionValue>> taggedDataSet,
         FlinkBatchTranslationContext context,
         int integerTag,
         PCollection<T> collection) {

http://git-wip-us.apache.org/repos/asf/beam/blob/ec5a8262/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java
----------------------------------------------------------------------
diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java
index 7081aad..9687478 100644
--- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java
+++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java
@@ -24,6 +24,7 @@ import org.apache.beam.runners.core.DoFnRunners;
 import org.apache.beam.runners.flink.translation.utils.SerializedPipelineOptions;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.join.RawUnionValue;
 import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
 import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
 import org.apache.beam.sdk.util.WindowedValue;
@@ -38,6 +39,10 @@ import org.apache.flink.util.Collector;
 /**
  * Encapsulates a {@link DoFn}
  * inside a Flink {@link org.apache.flink.api.common.functions.RichMapPartitionFunction}.
+ *
+ * <p>We get a mapping from {@link org.apache.beam.sdk.values.TupleTag} to output index
+ * and must tag all outputs with the output number. Afterwards a filter will filter out
+ * those elements that are not to be in a specific output.
  */
 public class FlinkDoFnFunction<InputT, OutputT>
     extends RichMapPartitionFunction<WindowedValue<InputT>, WindowedValue<OutputT>> {
@@ -49,18 +54,25 @@ public class FlinkDoFnFunction<InputT, OutputT>
 
   private final WindowingStrategy<?, ?> windowingStrategy;
 
+  private final Map<TupleTag<?>, Integer> outputMap;
+  private final TupleTag<OutputT> mainOutputTag;
+
   private transient DoFnInvoker<InputT, OutputT> doFnInvoker;
 
   public FlinkDoFnFunction(
       DoFn<InputT, OutputT> doFn,
       WindowingStrategy<?, ?> windowingStrategy,
       Map<PCollectionView<?>, WindowingStrategy<?, ?>> sideInputs,
-      PipelineOptions options) {
+      PipelineOptions options,
+      Map<TupleTag<?>, Integer> outputMap,
+      TupleTag<OutputT> mainOutputTag) {
 
     this.doFn = doFn;
     this.sideInputs = sideInputs;
     this.serializedOptions = new SerializedPipelineOptions(options);
     this.windowingStrategy = windowingStrategy;
+    this.outputMap = outputMap;
+    this.mainOutputTag = mainOutputTag;
 
   }
 
@@ -71,12 +83,21 @@ public class FlinkDoFnFunction<InputT, OutputT>
 
     RuntimeContext runtimeContext = getRuntimeContext();
 
+    DoFnRunners.OutputManager outputManager;
+    if (outputMap == null) {
+      outputManager = new FlinkDoFnFunction.DoFnOutputManager(out);
+    } else {
+      // it has some sideOutputs
+      outputManager =
+          new FlinkDoFnFunction.MultiDoFnOutputManager((Collector) out, outputMap);
+    }
+
     DoFnRunner<InputT, OutputT> doFnRunner = DoFnRunners.simpleRunner(
         serializedOptions.getPipelineOptions(), doFn,
         new FlinkSideInputReader(sideInputs, runtimeContext),
-        new DoFnOutputManager(out),
-        new TupleTag<OutputT>() {
-        },
+        outputManager,
+        mainOutputTag,
+        // see SimpleDoFnRunner, just use it to limit number of side outputs
         Collections.<TupleTag<?>>emptyList(),
         new FlinkNoOpStepContext(),
         new FlinkAggregatorFactory(runtimeContext),
@@ -102,12 +123,12 @@ public class FlinkDoFnFunction<InputT, OutputT>
     doFnInvoker.invokeTeardown();
   }
 
-  private class DoFnOutputManager
+  static class DoFnOutputManager
       implements DoFnRunners.OutputManager {
 
     private Collector collector;
 
-    DoFnOutputManager(Collector<WindowedValue<OutputT>> collector) {
+    DoFnOutputManager(Collector collector) {
       this.collector = collector;
     }
 
@@ -118,4 +139,23 @@ public class FlinkDoFnFunction<InputT, OutputT>
     }
   }
 
+  static class MultiDoFnOutputManager
+      implements DoFnRunners.OutputManager {
+
+    private Collector<WindowedValue<RawUnionValue>> collector;
+    private Map<TupleTag<?>, Integer> outputMap;
+
+    MultiDoFnOutputManager(Collector<WindowedValue<RawUnionValue>> collector,
+                      Map<TupleTag<?>, Integer> outputMap) {
+      this.collector = collector;
+      this.outputMap = outputMap;
+    }
+
+    @Override
+    public <T> void output(TupleTag<T> tag, WindowedValue<T> output) {
+      collector.collect(WindowedValue.of(new RawUnionValue(outputMap.get(tag), output.getValue()),
+          output.getTimestamp(), output.getWindows(), output.getPane()));
+    }
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/ec5a8262/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputDoFnFunction.java
----------------------------------------------------------------------
diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputDoFnFunction.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputDoFnFunction.java
deleted file mode 100644
index 27ba5ac..0000000
--- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputDoFnFunction.java
+++ /dev/null
@@ -1,131 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.flink.translation.functions;
-
-import java.util.Collections;
-import java.util.Map;
-import org.apache.beam.runners.core.DoFnRunner;
-import org.apache.beam.runners.core.DoFnRunners;
-import org.apache.beam.runners.flink.translation.utils.SerializedPipelineOptions;
-import org.apache.beam.sdk.options.PipelineOptions;
-import org.apache.beam.sdk.transforms.DoFn;
-import org.apache.beam.sdk.transforms.join.RawUnionValue;
-import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
-import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
-import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.util.WindowingStrategy;
-import org.apache.beam.sdk.values.PCollectionView;
-import org.apache.beam.sdk.values.TupleTag;
-import org.apache.flink.api.common.functions.RichMapPartitionFunction;
-import org.apache.flink.api.common.functions.RuntimeContext;
-import org.apache.flink.configuration.Configuration;
-import org.apache.flink.util.Collector;
-
-/**
- * Encapsulates a {@link DoFn} that can emit to multiple
- * outputs inside a Flink {@link org.apache.flink.api.common.functions.RichMapPartitionFunction}.
- *
- * <p>We get a mapping from {@link org.apache.beam.sdk.values.TupleTag} to output index
- * and must tag all outputs with the output number. Afterwards a filter will filter out
- * those elements that are not to be in a specific output.
- */
-public class FlinkMultiOutputDoFnFunction<InputT, OutputT>
-    extends RichMapPartitionFunction<WindowedValue<InputT>, WindowedValue<RawUnionValue>> {
-
-  private final DoFn<InputT, OutputT> doFn;
-  private final SerializedPipelineOptions serializedOptions;
-
-  private final Map<TupleTag<?>, Integer> outputMap;
-
-  private final Map<PCollectionView<?>, WindowingStrategy<?, ?>> sideInputs;
-  private final WindowingStrategy<?, ?> windowingStrategy;
-  private TupleTag<OutputT> mainOutputTag;
-  private transient DoFnInvoker<InputT, OutputT> doFnInvoker;
-
-  public FlinkMultiOutputDoFnFunction(
-      DoFn<InputT, OutputT> doFn,
-      WindowingStrategy<?, ?> windowingStrategy,
-      Map<PCollectionView<?>, WindowingStrategy<?, ?>> sideInputs,
-      PipelineOptions options,
-      Map<TupleTag<?>, Integer> outputMap,
-      TupleTag<OutputT> mainOutputTag) {
-    this.doFn = doFn;
-    this.serializedOptions = new SerializedPipelineOptions(options);
-    this.outputMap = outputMap;
-
-    this.windowingStrategy = windowingStrategy;
-    this.sideInputs = sideInputs;
-    this.mainOutputTag = mainOutputTag;
-  }
-
-  @Override
-  public void mapPartition(
-      Iterable<WindowedValue<InputT>> values,
-      Collector<WindowedValue<RawUnionValue>> out) throws Exception {
-
-    RuntimeContext runtimeContext = getRuntimeContext();
-
-    DoFnRunner<InputT, OutputT> doFnRunner = DoFnRunners.simpleRunner(
-        serializedOptions.getPipelineOptions(), doFn,
-        new FlinkSideInputReader(sideInputs, runtimeContext),
-        new DoFnOutputManager(out),
-        mainOutputTag,
-        // see SimpleDoFnRunner, just use it to limit number of side outputs
-        Collections.<TupleTag<?>>emptyList(),
-        new FlinkNoOpStepContext(),
-        new FlinkAggregatorFactory(runtimeContext),
-        windowingStrategy);
-
-    doFnRunner.startBundle();
-
-    for (WindowedValue<InputT> value : values) {
-      doFnRunner.processElement(value);
-    }
-
-    doFnRunner.finishBundle();
-
-  }
-
-  @Override
-  public void open(Configuration parameters) throws Exception {
-    doFnInvoker = DoFnInvokers.invokerFor(doFn);
-    doFnInvoker.invokeSetup();
-  }
-
-  @Override
-  public void close() throws Exception {
-    doFnInvoker.invokeTeardown();
-  }
-
-  private class DoFnOutputManager
-      implements DoFnRunners.OutputManager {
-
-    private Collector<WindowedValue<RawUnionValue>> collector;
-
-    DoFnOutputManager(Collector<WindowedValue<RawUnionValue>> collector) {
-      this.collector = collector;
-    }
-
-    @Override
-    public <T> void output(TupleTag<T> tag, WindowedValue<T> output) {
-      collector.collect(WindowedValue.of(new RawUnionValue(outputMap.get(tag), output.getValue()),
-          output.getTimestamp(), output.getWindows(), output.getPane()));
-    }
-  }
-
-}

http://git-wip-us.apache.org/repos/asf/beam/blob/ec5a8262/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputPruningFunction.java
----------------------------------------------------------------------
diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputPruningFunction.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputPruningFunction.java
index b72750a..9071cc5 100644
--- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputPruningFunction.java
+++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputPruningFunction.java
@@ -25,7 +25,7 @@ import org.apache.flink.util.Collector;
 /**
 * A {@link FlatMapFunction} function that filters out those elements that don't belong in this
  * output. We need this to implement MultiOutput ParDo functions in combination with
- * {@link FlinkMultiOutputDoFnFunction}.
+ * {@link FlinkDoFnFunction}.
  */
 public class FlinkMultiOutputPruningFunction<T>
     implements FlatMapFunction<WindowedValue<RawUnionValue>, WindowedValue<T>> {

http://git-wip-us.apache.org/repos/asf/beam/blob/ec5a8262/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkStatefulDoFnFunction.java
----------------------------------------------------------------------
diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkStatefulDoFnFunction.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkStatefulDoFnFunction.java
new file mode 100644
index 0000000..fca7691
--- /dev/null
+++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkStatefulDoFnFunction.java
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.translation.functions;
+
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.Map;
+import org.apache.beam.runners.core.DoFnRunner;
+import org.apache.beam.runners.core.DoFnRunners;
+import org.apache.beam.runners.core.InMemoryStateInternals;
+import org.apache.beam.runners.core.StateInternals;
+import org.apache.beam.runners.flink.translation.utils.SerializedPipelineOptions;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
+import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.util.WindowingStrategy;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.flink.api.common.functions.RichGroupReduceFunction;
+import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.util.Collector;
+
+/**
+ * A {@link RichGroupReduceFunction} for stateful {@link ParDo} in Flink Batch Runner.
+ */
+public class FlinkStatefulDoFnFunction<K, V, OutputT>
+    extends RichGroupReduceFunction<WindowedValue<KV<K, V>>, WindowedValue<OutputT>> {
+
+  private final DoFn<KV<K, V>, OutputT> dofn;
+  private final WindowingStrategy<?, ?> windowingStrategy;
+  private final Map<PCollectionView<?>, WindowingStrategy<?, ?>> sideInputs;
+  private final SerializedPipelineOptions serializedOptions;
+  private final Map<TupleTag<?>, Integer> outputMap;
+  private final TupleTag<OutputT> mainOutputTag;
+  private transient DoFnInvoker doFnInvoker;
+
+  public FlinkStatefulDoFnFunction(
+      DoFn<KV<K, V>, OutputT> dofn,
+      WindowingStrategy<?, ?> windowingStrategy,
+      Map<PCollectionView<?>, WindowingStrategy<?, ?>> sideInputs,
+      PipelineOptions pipelineOptions,
+      Map<TupleTag<?>, Integer> outputMap,
+      TupleTag<OutputT> mainOutputTag) {
+
+    this.dofn = dofn;
+    this.windowingStrategy = windowingStrategy;
+    this.sideInputs = sideInputs;
+    this.serializedOptions = new SerializedPipelineOptions(pipelineOptions);
+    this.outputMap = outputMap;
+    this.mainOutputTag = mainOutputTag;
+  }
+
+  @Override
+  public void reduce(
+      Iterable<WindowedValue<KV<K, V>>> values,
+      Collector<WindowedValue<OutputT>> out) throws Exception {
+    RuntimeContext runtimeContext = getRuntimeContext();
+
+    DoFnRunners.OutputManager outputManager;
+    if (outputMap == null) {
+      outputManager = new FlinkDoFnFunction.DoFnOutputManager(out);
+    } else {
+      // it has some sideOutputs
+      outputManager =
+          new FlinkDoFnFunction.MultiDoFnOutputManager((Collector) out, outputMap);
+    }
+
+    final Iterator<WindowedValue<KV<K, V>>> iterator = values.iterator();
+
+    // get the first value, we need this for initializing the state internals with the key.
+    // we are guaranteed to have a first value, otherwise reduce() would not have been called.
+    WindowedValue<KV<K, V>> currentValue = iterator.next();
+    final K key = currentValue.getValue().getKey();
+
+    final InMemoryStateInternals<K> stateInternals = InMemoryStateInternals.forKey(key);
+    DoFnRunner<KV<K, V>, OutputT> doFnRunner = DoFnRunners.simpleRunner(
+        serializedOptions.getPipelineOptions(), dofn,
+        new FlinkSideInputReader(sideInputs, runtimeContext),
+        outputManager,
+        mainOutputTag,
+        // see SimpleDoFnRunner, just use it to limit number of side outputs
+        Collections.<TupleTag<?>>emptyList(),
+        new FlinkNoOpStepContext() {
+          @Override
+          public StateInternals<?> stateInternals() {
+            return stateInternals;
+          }
+        },
+        new FlinkAggregatorFactory(runtimeContext),
+        windowingStrategy);
+
+    doFnRunner.startBundle();
+
+    doFnRunner.processElement(currentValue);
+    while (iterator.hasNext()) {
+      currentValue = iterator.next();
+      doFnRunner.processElement(currentValue);
+    }
+
+    doFnRunner.finishBundle();
+  }
+
+  @Override
+  public void open(Configuration parameters) throws Exception {
+    doFnInvoker = DoFnInvokers.invokerFor(dofn);
+    doFnInvoker.invokeSetup();
+  }
+
+  @Override
+  public void close() throws Exception {
+    doFnInvoker.invokeTeardown();
+  }
+
+}


Mime
View raw message