beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From al...@apache.org
Subject [5/9] beam git commit: Explicitly GBK before stateful ParDo in Dataflow batch
Date Wed, 08 Mar 2017 22:25:20 GMT
Explicitly GBK before stateful ParDo in Dataflow batch


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/5e2afa29
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/5e2afa29
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/5e2afa29

Branch: refs/heads/release-0.6.0
Commit: 5e2afa29a3a0fe93e662b2fe7173c1641c253cd5
Parents: 92c5b5b
Author: Kenneth Knowles <klk@google.com>
Authored: Thu Mar 2 14:29:56 2017 -0800
Committer: Ahmet Altay <altay@google.com>
Committed: Wed Mar 8 13:41:01 2017 -0800

----------------------------------------------------------------------
 .../dataflow/BatchStatefulParDoOverrides.java   | 283 +++++++++++++++++++
 .../dataflow/DataflowPipelineTranslator.java    |   5 +-
 .../beam/runners/dataflow/DataflowRunner.java   |   6 +
 .../beam/runners/dataflow/util/DoFnInfo.java    |   9 +
 .../BatchStatefulParDoOverridesTest.java        | 169 +++++++++++
 .../DataflowPipelineTranslatorTest.java         |  72 +++++
 .../apache/beam/sdk/transforms/ParDoTest.java   | 117 ++++++++
 7 files changed, 660 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/5e2afa29/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java
----------------------------------------------------------------------
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java
b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java
new file mode 100644
index 0000000..91f84ab
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java
@@ -0,0 +1,283 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow;
+
+import static com.google.common.base.Preconditions.checkState;
+
+import com.google.common.collect.Iterables;
+import java.util.List;
+import java.util.Map;
+import org.apache.beam.runners.core.construction.ReplacementOutputs;
+import org.apache.beam.runners.dataflow.BatchViewOverrides.GroupByKeyAndSortValuesOnly;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.InstantCoder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.runners.PTransformOverrideFactory;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.ParDo.BoundMulti;
+import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
+import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.util.WindowingStrategy;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionTuple;
+import org.apache.beam.sdk.values.PValue;
+import org.apache.beam.sdk.values.TaggedPValue;
+import org.apache.beam.sdk.values.TypeDescriptor;
+import org.joda.time.Instant;
+
+/**
+ * {@link PTransformOverrideFactory PTransformOverrideFactories} that expand to correctly implement
+ * stateful {@link ParDo} using window-unaware {@link GroupByKeyAndSortValuesOnly} to linearize
+ * processing per key.
+ *
+ * <p>This implementation relies on implementation details of the Dataflow runner, specifically
+ * standard fusion behavior of {@link ParDo} transforms following a {@link GroupByKey}.
+ */
+public class BatchStatefulParDoOverrides {
+
+  /**
+   * Returns a {@link PTransformOverrideFactory} that replaces a single-output
+   * {@link ParDo} with a composite transform specialized for the {@link DataflowRunner}.
+   */
+  public static <K, InputT, OutputT>
+      PTransformOverrideFactory<
+              PCollection<KV<K, InputT>>, PCollection<OutputT>, ParDo.Bound<KV<K,
InputT>, OutputT>>
+          singleOutputOverrideFactory() {
+    return new SingleOutputOverrideFactory<>();
+  }
+
+  /**
+   * Returns a {@link PTransformOverrideFactory} that replaces a multi-output
+   * {@link ParDo} with a composite transform specialized for the {@link DataflowRunner}.
+   */
+  public static <K, InputT, OutputT>
+      PTransformOverrideFactory<
+              PCollection<KV<K, InputT>>, PCollectionTuple,
+              ParDo.BoundMulti<KV<K, InputT>, OutputT>>
+          multiOutputOverrideFactory() {
+    return new MultiOutputOverrideFactory<>();
+  }
+
+  private static class SingleOutputOverrideFactory<K, InputT, OutputT>
+      implements PTransformOverrideFactory<
+          PCollection<KV<K, InputT>>, PCollection<OutputT>, ParDo.Bound<KV<K,
InputT>, OutputT>> {
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public PTransform<PCollection<KV<K, InputT>>, PCollection<OutputT>>
getReplacementTransform(
+        ParDo.Bound<KV<K, InputT>, OutputT> originalParDo) {
+      return new StatefulSingleOutputParDo<>(originalParDo);
+    }
+
+    @Override
+    public PCollection<KV<K, InputT>> getInput(List<TaggedPValue> inputs,
Pipeline p) {
+      return (PCollection<KV<K, InputT>>) Iterables.getOnlyElement(inputs).getValue();
+    }
+
+    @Override
+    public Map<PValue, ReplacementOutput> mapOutputs(
+        List<TaggedPValue> outputs, PCollection<OutputT> newOutput) {
+      return ReplacementOutputs.singleton(outputs, newOutput);
+    }
+  }
+
+  private static class MultiOutputOverrideFactory<K, InputT, OutputT>
+      implements PTransformOverrideFactory<
+          PCollection<KV<K, InputT>>, PCollectionTuple, ParDo.BoundMulti<KV<K,
InputT>, OutputT>> {
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public PTransform<PCollection<KV<K, InputT>>, PCollectionTuple> getReplacementTransform(
+        BoundMulti<KV<K, InputT>, OutputT> originalParDo) {
+      return new StatefulMultiOutputParDo<>(originalParDo);
+    }
+
+    @Override
+    public PCollection<KV<K, InputT>> getInput(List<TaggedPValue> inputs,
Pipeline p) {
+      return (PCollection<KV<K, InputT>>) Iterables.getOnlyElement(inputs).getValue();
+    }
+
+    @Override
+    public Map<PValue, ReplacementOutput> mapOutputs(
+        List<TaggedPValue> outputs, PCollectionTuple newOutput) {
+      return ReplacementOutputs.tagged(outputs, newOutput);
+    }
+  }
+
+  static class StatefulSingleOutputParDo<K, InputT, OutputT>
+      extends PTransform<PCollection<KV<K, InputT>>, PCollection<OutputT>>
{
+
+    private final ParDo.Bound<KV<K, InputT>, OutputT> originalParDo;
+
+    StatefulSingleOutputParDo(ParDo.Bound<KV<K, InputT>, OutputT> originalParDo)
{
+      this.originalParDo = originalParDo;
+    }
+
+    ParDo.Bound<KV<K, InputT>, OutputT> getOriginalParDo() {
+      return originalParDo;
+    }
+
+    @Override
+    public PCollection<OutputT> expand(PCollection<KV<K, InputT>> input)
{
+      DoFn<KV<K, InputT>, OutputT> fn = originalParDo.getFn();
+      verifyFnIsStateful(fn);
+
+      PTransform<
+              PCollection<? extends KV<K, Iterable<KV<Instant, WindowedValue<KV<K,
InputT>>>>>>,
+              PCollection<OutputT>>
+          statefulParDo =
+              ParDo.of(new BatchStatefulDoFn<>(fn)).withSideInputs(originalParDo.getSideInputs());
+
+      return input.apply(new GbkBeforeStatefulParDo<K, InputT>()).apply(statefulParDo);
+    }
+  }
+
+  static class StatefulMultiOutputParDo<K, InputT, OutputT>
+      extends PTransform<PCollection<KV<K, InputT>>, PCollectionTuple>
{
+
+    private final BoundMulti<KV<K, InputT>, OutputT> originalParDo;
+
+    StatefulMultiOutputParDo(ParDo.BoundMulti<KV<K, InputT>, OutputT> originalParDo)
{
+      this.originalParDo = originalParDo;
+    }
+
+    @Override
+    public PCollectionTuple expand(PCollection<KV<K, InputT>> input) {
+      DoFn<KV<K, InputT>, OutputT> fn = originalParDo.getFn();
+      verifyFnIsStateful(fn);
+
+      PTransform<
+              PCollection<? extends KV<K, Iterable<KV<Instant, WindowedValue<KV<K,
InputT>>>>>>,
+              PCollectionTuple>
+          statefulParDo =
+              ParDo.of(new BatchStatefulDoFn<K, InputT, OutputT>(fn))
+                  .withSideInputs(originalParDo.getSideInputs())
+                  .withOutputTags(
+                      originalParDo.getMainOutputTag(), originalParDo.getSideOutputTags());
+
+      return input.apply(new GbkBeforeStatefulParDo<K, InputT>()).apply(statefulParDo);
+    }
+
+    public BoundMulti<KV<K, InputT>, OutputT> getOriginalParDo() {
+      return originalParDo;
+    }
+  }
+
+  static class GbkBeforeStatefulParDo<K, V>
+      extends PTransform<
+          PCollection<KV<K, V>>,
+          PCollection<KV<K, Iterable<KV<Instant, WindowedValue<KV<K, V>>>>>>>
{
+
+    @Override
+    public PCollection<KV<K, Iterable<KV<Instant, WindowedValue<KV<K, V>>>>>>
expand(
+        PCollection<KV<K, V>> input) {
+
+      WindowingStrategy<?, ?> inputWindowingStrategy = input.getWindowingStrategy();
+
+      // A KvCoder is required since this goes through GBK. Further, WindowedValueCoder
+      // is not registered by default, so we explicitly set the relevant coders.
+      checkState(
+          input.getCoder() instanceof KvCoder,
+          "Input to a %s using state requires a %s, but the coder was %s",
+          ParDo.class.getSimpleName(),
+          KvCoder.class.getSimpleName(),
+          input.getCoder());
+      KvCoder<K, V> kvCoder = (KvCoder<K, V>) input.getCoder();
+      Coder<K> keyCoder = kvCoder.getKeyCoder();
+      Coder<? extends BoundedWindow> windowCoder =
+          inputWindowingStrategy.getWindowFn().windowCoder();
+
+      return input
+          // Stash the original timestamps, etc, for when it is fed to the user's DoFn
+          .apply("ReifyWindows", ParDo.of(new ReifyWindowedValueFn<K, V>()))
+          .setCoder(
+              KvCoder.of(
+                  keyCoder,
+                  KvCoder.of(InstantCoder.of(), WindowedValue.getFullCoder(kvCoder, windowCoder))))
+
+          // Group by key and sort by timestamp, dropping windows as they are reified
+          .apply(
+              "PartitionKeys",
+              new GroupByKeyAndSortValuesOnly<K, Instant, WindowedValue<KV<K, V>>>())
+
+          // The GBKO sets the windowing strategy to the global default
+          .setWindowingStrategyInternal(inputWindowingStrategy);
+    }
+  }
+
+  /** A key-preserving {@link DoFn} that reifies a windowed value. */
+  static class ReifyWindowedValueFn<K, V>
+      extends DoFn<KV<K, V>, KV<K, KV<Instant, WindowedValue<KV<K, V>>>>>
{
+    @ProcessElement
+    public void processElement(final ProcessContext c, final BoundedWindow window) {
+      c.output(
+          KV.of(
+              c.element().getKey(),
+              KV.of(
+                  c.timestamp(), WindowedValue.of(c.element(), c.timestamp(), window, c.pane()))));
+    }
+  }
+
+  /**
+   * A key-preserving {@link DoFn} that explodes an iterable that has been grouped by key and
+   * window.
+   */
+  public static class BatchStatefulDoFn<K, V, OutputT>
+      extends DoFn<KV<K, Iterable<KV<Instant, WindowedValue<KV<K, V>>>>>,
OutputT> {
+
+    private final DoFn<KV<K, V>, OutputT> underlyingDoFn;
+
+    BatchStatefulDoFn(DoFn<KV<K, V>, OutputT> underlyingDoFn) {
+      this.underlyingDoFn = underlyingDoFn;
+    }
+
+    public DoFn<KV<K, V>, OutputT> getUnderlyingDoFn() {
+      return underlyingDoFn;
+    }
+
+    @ProcessElement
+    public void processElement(final ProcessContext c, final BoundedWindow window) {
+      throw new UnsupportedOperationException(
+          "BatchStatefulDoFn.ProcessElement should never be invoked");
+    }
+
+    @Override
+    public TypeDescriptor<OutputT> getOutputTypeDescriptor() {
+      return underlyingDoFn.getOutputTypeDescriptor();
+    }
+  }
+
+  private static <InputT, OutputT> void verifyFnIsStateful(DoFn<InputT, OutputT>
fn) {
+    DoFnSignature signature = DoFnSignatures.getSignature(fn.getClass());
+
+    // It is still correct to use this without state or timers, but a bad idea.
+    // Since it is internal, it should never be used incorrectly, so it is OK to crash.
+    checkState(
+        signature.usesState() || signature.usesTimers(),
+        "%s used for %s that does not use state or timers.",
+        BatchStatefulParDoOverrides.class.getSimpleName(),
+        ParDo.class.getSimpleName());
+  }
+}

http://git-wip-us.apache.org/repos/asf/beam/blob/5e2afa29/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java
----------------------------------------------------------------------
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java
b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java
index 7e559e9..fc47593 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java
@@ -956,7 +956,10 @@ public class DataflowPipelineTranslator {
                 DoFnInfo.forFn(
                     fn, windowingStrategy, sideInputs, inputCoder, mainOutput, outputMap))));
 
-    if (signature.usesState() || signature.usesTimers()) {
+    // Setting USES_KEYED_STATE will cause an ungrouped shuffle, which works
+    // in streaming but does not work in batch
+    if (context.getPipelineOptions().isStreaming()
+        && (signature.usesState() || signature.usesTimers())) {
       stepContext.addInput(PropertyNames.USES_KEYED_STATE, "true");
     }
   }

http://git-wip-us.apache.org/repos/asf/beam/blob/5e2afa29/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
----------------------------------------------------------------------
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
index 15147f1..b782786 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
@@ -341,6 +341,12 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob>
{
           UnsupportedOverrideFactory.withMessage(
               "The DataflowRunner in batch mode does not support Read.Unbounded"));
       ptoverrides
+          // State and timer pardos are implemented by expansion to GBK-then-ParDo
+          .put(PTransformMatchers.stateOrTimerParDoMulti(),
+              BatchStatefulParDoOverrides.multiOutputOverrideFactory())
+          .put(PTransformMatchers.stateOrTimerParDoSingle(),
+              BatchStatefulParDoOverrides.singleOutputOverrideFactory())
+
           // Write uses views internally
           .put(PTransformMatchers.classEqualTo(Write.class), new BatchWriteFactory(this))
           .put(

http://git-wip-us.apache.org/repos/asf/beam/blob/5e2afa29/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/DoFnInfo.java
----------------------------------------------------------------------
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/DoFnInfo.java
b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/DoFnInfo.java
index 4d80a39..55c62ae 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/DoFnInfo.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/DoFnInfo.java
@@ -72,6 +72,15 @@ public class DoFnInfo<InputT, OutputT> implements Serializable {
         outputMap);
   }
 
+  public DoFnInfo<InputT, OutputT> withFn(DoFn<InputT, OutputT> newFn) {
+    return DoFnInfo.forFn(newFn,
+        windowingStrategy,
+        sideInputViews,
+        inputCoder,
+        mainOutput,
+        outputMap);
+  }
+
   private DoFnInfo(
       DoFn<InputT, OutputT> doFn,
       WindowingStrategy<?, ?> windowingStrategy,

http://git-wip-us.apache.org/repos/asf/beam/blob/5e2afa29/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverridesTest.java
----------------------------------------------------------------------
diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverridesTest.java
b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverridesTest.java
new file mode 100644
index 0000000..ef3e414
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverridesTest.java
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.dataflow;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.nullValue;
+import static org.junit.Assert.assertThat;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import com.google.common.collect.ImmutableList;
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.LinkedList;
+import java.util.List;
+import javax.annotation.Nullable;
+import org.apache.beam.runners.dataflow.BatchStatefulParDoOverrides.StatefulMultiOutputParDo;
+import org.apache.beam.runners.dataflow.BatchStatefulParDoOverrides.StatefulSingleOutputParDo;
+import org.apache.beam.runners.dataflow.options.DataflowPipelineOptions;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.Pipeline.PipelineVisitor;
+import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.runners.TransformHierarchy.Node;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.util.GcsUtil;
+import org.apache.beam.sdk.util.TestCredential;
+import org.apache.beam.sdk.util.gcsfs.GcsPath;
+import org.apache.beam.sdk.util.state.StateSpec;
+import org.apache.beam.sdk.util.state.StateSpecs;
+import org.apache.beam.sdk.util.state.ValueState;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.TupleTagList;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+
+/** Tests for {@link BatchStatefulParDoOverrides}. */
+@RunWith(JUnit4.class)
+public class BatchStatefulParDoOverridesTest implements Serializable {
+
+  @Test
+  public void testSingleOutputOverrideNonCrashing() throws Exception {
+    DataflowPipelineOptions options = buildPipelineOptions();
+    options.setRunner(DataflowRunner.class);
+    Pipeline pipeline = Pipeline.create(options);
+
+    DummyStatefulDoFn fn = new DummyStatefulDoFn();
+    pipeline.apply(Create.of(KV.of(1, 2))).apply(ParDo.of(fn));
+
+    DataflowRunner runner = (DataflowRunner) pipeline.getRunner();
+    runner.replaceTransforms(pipeline);
+    assertThat(findBatchStatefulDoFn(pipeline), equalTo((DoFn) fn));
+  }
+
+  @Test
+  public void testMultiOutputOverrideNonCrashing() throws Exception {
+    DataflowPipelineOptions options = buildPipelineOptions();
+    options.setRunner(DataflowRunner.class);
+    Pipeline pipeline = Pipeline.create(options);
+
+    TupleTag<Integer> mainOutputTag = new TupleTag<Integer>() {};
+
+    DummyStatefulDoFn fn = new DummyStatefulDoFn();
+    pipeline
+        .apply(Create.of(KV.of(1, 2)))
+        .apply(ParDo.withOutputTags(mainOutputTag, TupleTagList.empty()).of(fn));
+
+    DataflowRunner runner = (DataflowRunner) pipeline.getRunner();
+    runner.replaceTransforms(pipeline);
+    assertThat(findBatchStatefulDoFn(pipeline), equalTo((DoFn) fn));
+  }
+
+  private static DummyStatefulDoFn findBatchStatefulDoFn(Pipeline p) {
+    FindBatchStatefulDoFnVisitor findBatchStatefulDoFnVisitor = new FindBatchStatefulDoFnVisitor();
+    p.traverseTopologically(findBatchStatefulDoFnVisitor);
+    return (DummyStatefulDoFn) findBatchStatefulDoFnVisitor.getStatefulDoFn();
+  }
+
+  private static class DummyStatefulDoFn extends DoFn<KV<Integer, Integer>, Integer>
{
+
+    @StateId("foo")
+    private final StateSpec<Object, ValueState<Integer>> spec = StateSpecs.value(VarIntCoder.of());
+
+    @ProcessElement
+    public void processElem(ProcessContext c) {
+      // noop
+    }
+
+    @Override
+    public boolean equals(Object other) {
+      return other instanceof DummyStatefulDoFn;
+    }
+
+    @Override
+    public int hashCode() {
+      return getClass().hashCode();
+    }
+  }
+
+  private static class FindBatchStatefulDoFnVisitor extends PipelineVisitor.Defaults {
+
+    @Nullable private DoFn<?, ?> batchStatefulDoFn;
+
+    public DoFn<?, ?> getStatefulDoFn() {
+      assertThat(batchStatefulDoFn, not(nullValue()));
+      return batchStatefulDoFn;
+    }
+
+    @Override
+    public CompositeBehavior enterCompositeTransform(Node node) {
+      if (node.getTransform() instanceof StatefulSingleOutputParDo) {
+        batchStatefulDoFn =
+            ((StatefulSingleOutputParDo) node.getTransform()).getOriginalParDo().getFn();
+        return CompositeBehavior.DO_NOT_ENTER_TRANSFORM;
+      } else if (node.getTransform() instanceof StatefulMultiOutputParDo) {
+        batchStatefulDoFn =
+            ((StatefulMultiOutputParDo) node.getTransform()).getOriginalParDo().getFn();
+        return CompositeBehavior.DO_NOT_ENTER_TRANSFORM;
+      } else {
+        return CompositeBehavior.ENTER_TRANSFORM;
+      }
+    }
+  }
+
+  private static DataflowPipelineOptions buildPipelineOptions() throws IOException {
+    GcsUtil mockGcsUtil = mock(GcsUtil.class);
+    when(mockGcsUtil.expand(any(GcsPath.class))).then(new Answer<List<GcsPath>>()
{
+      @Override
+      public List<GcsPath> answer(InvocationOnMock invocation) throws Throwable {
+        return ImmutableList.of((GcsPath) invocation.getArguments()[0]);
+      }
+    });
+    when(mockGcsUtil.bucketAccessible(any(GcsPath.class))).thenReturn(true);
+
+    DataflowPipelineOptions options = PipelineOptionsFactory.as(DataflowPipelineOptions.class);
+    options.setRunner(DataflowRunner.class);
+    options.setGcpCredential(new TestCredential());
+    options.setJobName("some-job-name");
+    options.setProject("some-project");
+    options.setTempLocation(GcsPath.fromComponents("somebucket", "some/path").toString());
+    options.setFilesToStage(new LinkedList<String>());
+    options.setGcsUtil(mockGcsUtil);
+    return options;
+  }
+}

http://git-wip-us.apache.org/repos/asf/beam/blob/5e2afa29/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java
----------------------------------------------------------------------
diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java
b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java
index d4271e5..660e92e 100644
--- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java
+++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java
@@ -22,8 +22,10 @@ import static org.apache.beam.sdk.util.Structs.getDictionary;
 import static org.apache.beam.sdk.util.Structs.getString;
 import static org.hamcrest.Matchers.allOf;
 import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasEntry;
 import static org.hamcrest.Matchers.hasKey;
+import static org.hamcrest.Matchers.not;
 import static org.hamcrest.core.IsInstanceOf.instanceOf;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNull;
@@ -81,10 +83,15 @@ import org.apache.beam.sdk.util.Structs;
 import org.apache.beam.sdk.util.TestCredential;
 import org.apache.beam.sdk.util.WindowingStrategy;
 import org.apache.beam.sdk.util.gcsfs.GcsPath;
+import org.apache.beam.sdk.util.state.StateSpec;
+import org.apache.beam.sdk.util.state.StateSpecs;
+import org.apache.beam.sdk.util.state.ValueState;
+import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionTuple;
 import org.apache.beam.sdk.values.PDone;
 import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.TupleTagList;
 import org.junit.Assert;
 import org.junit.Rule;
 import org.junit.Test;
@@ -825,6 +832,71 @@ public class DataflowPipelineTranslatorTest implements Serializable {
     assertEquals("CollectionToSingleton", collectionToSingletonStep.getKind());
   }
 
+  /**
+   * Smoke test to fail fast if translation of a stateful ParDo
+   * in batch breaks.
+   */
+  @Test
+  public void testBatchStatefulParDoTranslation() throws Exception {
+    DataflowPipelineOptions options = buildPipelineOptions();
+    DataflowRunner runner = DataflowRunner.fromOptions(options);
+    options.setStreaming(false);
+    DataflowPipelineTranslator translator = DataflowPipelineTranslator.fromOptions(options);
+
+    Pipeline pipeline = Pipeline.create(options);
+
+    TupleTag<Integer> mainOutputTag = new TupleTag<Integer>() {};
+
+    pipeline
+        .apply(Create.of(KV.of(1, 1)))
+        .apply(
+            ParDo.withOutputTags(mainOutputTag, TupleTagList.empty()).of(
+                new DoFn<KV<Integer, Integer>, Integer>() {
+                  @StateId("unused")
+                  final StateSpec<Object, ValueState<Integer>> stateSpec =
+                      StateSpecs.value(VarIntCoder.of());
+
+                  @ProcessElement
+                  public void process(ProcessContext c) {
+                    // noop
+                  }
+                }));
+
+    runner.replaceTransforms(pipeline);
+
+    Job job =
+        translator
+            .translate(
+                pipeline,
+                runner,
+                Collections.<DataflowPackage>emptyList())
+            .getJob();
+
+    // The job should look like:
+    // 0. ParallelRead (Create)
+    // 1. ParDo(ReifyWVs)
+    // 2. GroupByKeyAndSortValuesOnly
+    // 3. A ParDo over grouped and sorted KVs that is executed via ungrouping service-side
+
+    List<Step> steps = job.getSteps();
+    assertEquals(4, steps.size());
+
+    Step createStep = steps.get(0);
+    assertEquals("ParallelRead", createStep.getKind());
+
+    Step reifyWindowedValueStep = steps.get(1);
+    assertEquals("ParallelDo", reifyWindowedValueStep.getKind());
+
+    Step gbkStep = steps.get(2);
+    assertEquals("GroupByKey", gbkStep.getKind());
+
+    Step statefulParDoStep = steps.get(3);
+    assertEquals("ParallelDo", statefulParDoStep.getKind());
+    assertThat(
+        (String) statefulParDoStep.getProperties().get(PropertyNames.USES_KEYED_STATE),
+        not(equalTo("true")));
+  }
+
   @Test
   public void testToSingletonTranslationWithIsmSideInput() throws Exception {
     // A "change detector" test that makes sure the translation

http://git-wip-us.apache.org/repos/asf/beam/blob/5e2afa29/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
index 6db0af4..e58f78e 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
@@ -48,6 +48,7 @@ import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -1506,6 +1507,54 @@ public class ParDoTest implements Serializable {
 
   @Test
   @Category({RunnableOnService.class, UsesStatefulParDo.class})
+  public void testValueStateDedup() {
+    final String stateId = "foo";
+
+    DoFn<KV<Integer, Integer>, Integer> onePerKey =
+        new DoFn<KV<Integer, Integer>, Integer>() {
+
+          @StateId(stateId)
+          private final StateSpec<Object, ValueState<Integer>> seenSpec =
+              StateSpecs.value(VarIntCoder.of());
+
+          @ProcessElement
+          public void processElement(
+              ProcessContext c, @StateId(stateId) ValueState<Integer> seenState) {
+            Integer seen = MoreObjects.firstNonNull(seenState.read(), 0);
+
+            if (seen == 0) {
+              seenState.write(seen + 1);
+              c.output(c.element().getValue());
+            }
+          }
+        };
+
+    int numKeys = 50;
+    // A big enough list that we can see some deduping
+    List<KV<Integer, Integer>> input = new ArrayList<>();
+
+    // The output should have no dupes
+    Set<Integer> expectedOutput = new HashSet<>();
+
+    for (int key = 0; key < numKeys; ++key) {
+      int output = 1000 + key;
+      expectedOutput.add(output);
+
+      for (int i = 0; i < 15; ++i) {
+        input.add(KV.of(key, output));
+      }
+    }
+
+    Collections.shuffle(input);
+
+    PCollection<Integer> output = pipeline.apply(Create.of(input)).apply(ParDo.of(onePerKey));
+
+    PAssert.that(output).containsInAnyOrder(expectedOutput);
+    pipeline.run();
+  }
+
+  @Test
+  @Category({RunnableOnService.class, UsesStatefulParDo.class})
   public void testValueStateFixedWindows() {
     final String stateId = "foo";
 
@@ -1936,6 +1985,74 @@ public class ParDoTest implements Serializable {
     pipeline.run();
   }
 
+  /**
+   * Tests that event time timers for multiple keys both fire. This particularly exercises
+   * implementations that may GC in ways not simply governed by the watermark.
+   */
+  @Test
+  @Category({RunnableOnService.class, UsesTimersInParDo.class})
+  public void testEventTimeTimerMultipleKeys() throws Exception {
+    final String timerId = "foo";
+    final String stateId = "sizzle";
+
+    final int offset = 5000;
+    final int timerOutput = 4093;
+
+    DoFn<KV<String, Integer>, KV<String, Integer>> fn =
+        new DoFn<KV<String, Integer>, KV<String, Integer>>() {
+
+          @TimerId(timerId)
+          private final TimerSpec spec = TimerSpecs.timer(TimeDomain.EVENT_TIME);
+
+          @StateId(stateId)
+          private final StateSpec<Object, ValueState<String>> stateSpec =
+              StateSpecs.value(StringUtf8Coder.of());
+
+          @ProcessElement
+          public void processElement(
+              ProcessContext context,
+              @TimerId(timerId) Timer timer,
+              @StateId(stateId) ValueState<String> state,
+              BoundedWindow window) {
+            timer.set(window.maxTimestamp());
+            state.write(context.element().getKey());
+            context.output(
+                KV.of(context.element().getKey(), context.element().getValue() + offset));
+          }
+
+          @OnTimer(timerId)
+          public void onTimer(OnTimerContext context, @StateId(stateId) ValueState<String>
state) {
+            context.output(KV.of(state.read(), timerOutput));
+          }
+        };
+
+    // Enough keys that we exercise interesting code paths
+    int numKeys = 50;
+    List<KV<String, Integer>> input = new ArrayList<>();
+    List<KV<String, Integer>> expectedOutput = new ArrayList<>();
+
+    for (Integer key = 0; key < numKeys; ++key) {
+      // Each key should have just one final output at GC time
+      expectedOutput.add(KV.of(key.toString(), timerOutput));
+
+      for (int i = 0; i < 15; ++i) {
+        // Each input should be output with the offset added
+        input.add(KV.of(key.toString(), i));
+        expectedOutput.add(KV.of(key.toString(), i + offset));
+      }
+    }
+
+    Collections.shuffle(input);
+
+    PCollection<KV<String, Integer>> output =
+        pipeline
+            .apply(
+                Create.of(input))
+            .apply(ParDo.of(fn));
+    PAssert.that(output).containsInAnyOrder(expectedOutput);
+    pipeline.run();
+  }
+
   @Test
   @Category({RunnableOnService.class, UsesTimersInParDo.class})
   public void testAbsoluteProcessingTimeTimerRejected() throws Exception {


Mime
View raw message