Return-Path: X-Original-To: archive-asf-public-internal@cust-asf2.ponee.io Delivered-To: archive-asf-public-internal@cust-asf2.ponee.io Received: from cust-asf.ponee.io (cust-asf.ponee.io [163.172.22.183]) by cust-asf2.ponee.io (Postfix) with ESMTP id 82285200C4F for ; Sat, 1 Apr 2017 09:28:38 +0200 (CEST) Received: by cust-asf.ponee.io (Postfix) id 808AA160B9D; Sat, 1 Apr 2017 07:28:38 +0000 (UTC) Delivered-To: archive-asf-public@cust-asf.ponee.io Received: from mail.apache.org (hermes.apache.org [140.211.11.3]) by cust-asf.ponee.io (Postfix) with SMTP id CF182160B8D for ; Sat, 1 Apr 2017 09:28:36 +0200 (CEST) Received: (qmail 76055 invoked by uid 500); 1 Apr 2017 07:28:36 -0000 Mailing-List: contact commits-help@beam.apache.org; run by ezmlm Precedence: bulk List-Help: List-Unsubscribe: List-Post: List-Id: Reply-To: dev@beam.apache.org Delivered-To: mailing list commits@beam.apache.org Received: (qmail 76037 invoked by uid 99); 1 Apr 2017 07:28:35 -0000 Received: from git1-us-west.apache.org (HELO git1-us-west.apache.org) (140.211.11.23) by apache.org (qpsmtpd/0.29) with ESMTP; Sat, 01 Apr 2017 07:28:35 +0000 Received: by git1-us-west.apache.org (ASF Mail Server at git1-us-west.apache.org, from userid 33) id A9AD2DFF66; Sat, 1 Apr 2017 07:28:35 +0000 (UTC) Content-Type: text/plain; charset="us-ascii" MIME-Version: 1.0 Content-Transfer-Encoding: 7bit From: aviemzur@apache.org To: commits@beam.apache.org Date: Sat, 01 Apr 2017 07:28:35 -0000 Message-Id: X-Mailer: ASF-Git Admin Mailer Subject: [1/2] beam git commit: [BEAM-1337] Infer state coders archived-at: Sat, 01 Apr 2017 07:28:38 -0000 Repository: beam Updated Branches: refs/heads/master 03dce6dcc -> e31ca8b0d [BEAM-1337] Infer state coders Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/42e690e8 Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/42e690e8 Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/42e690e8 Branch: refs/heads/master Commit: 42e690e84a9f05d508f2528b1444b26ce031e080 Parents: 03dce6d Author: Aviem Zur Authored: Wed Mar 1 07:27:57 2017 +0200 Committer: Aviem Zur Committed: Sat Apr 1 10:27:14 2017 +0300 ---------------------------------------------------------------------- .../org/apache/beam/sdk/transforms/ParDo.java | 62 ++ .../apache/beam/sdk/util/state/StateSpec.java | 15 + .../apache/beam/sdk/util/state/StateSpecs.java | 264 ++++++++- .../apache/beam/sdk/transforms/ParDoTest.java | 578 +++++++++++++++++++ 4 files changed, 902 insertions(+), 17 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/42e690e8/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java index 664fbc3..3de845b 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java @@ -22,6 +22,8 @@ import static com.google.common.base.Preconditions.checkArgument; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.io.Serializable; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -29,6 +31,7 @@ import java.util.Map; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderRegistry; import org.apache.beam.sdk.runners.PipelineRunner; import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.transforms.display.DisplayData.Builder; @@ -41,6 +44,7 @@ import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.WindowFn; import org.apache.beam.sdk.util.NameUtils; import org.apache.beam.sdk.util.SerializableUtils; +import org.apache.beam.sdk.util.state.StateSpec; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.PCollectionView; @@ -434,6 +438,59 @@ public class ParDo { return DisplayData.item("fn", fn.getClass()).withLabel("Transform Function"); } + private static void finishSpecifyingStateSpecs( + DoFn fn, + CoderRegistry coderRegistry, + Coder inputCoder) { + DoFnSignature signature = DoFnSignatures.getSignature(fn.getClass()); + Map stateDeclarations = signature.stateDeclarations(); + for (DoFnSignature.StateDeclaration stateDeclaration : stateDeclarations.values()) { + try { + StateSpec stateSpec = (StateSpec) stateDeclaration.field().get(fn); + stateSpec.offerCoders(codersForStateSpecTypes(stateDeclaration, coderRegistry, inputCoder)); + stateSpec.finishSpecifying(); + } catch (IllegalAccessException e) { + throw new RuntimeException(e); + } + } + } + + /** + * Try to provide coders for as many of the type arguments of given + * {@link DoFnSignature.StateDeclaration} as possible. + */ + private static Coder[] codersForStateSpecTypes( + DoFnSignature.StateDeclaration stateDeclaration, + CoderRegistry coderRegistry, + Coder inputCoder) { + Type stateType = stateDeclaration.stateType().getType(); + + if (!(stateType instanceof ParameterizedType)) { + // No type arguments means no coders to infer. + return new Coder[0]; + } + + Type[] typeArguments = ((ParameterizedType) stateType).getActualTypeArguments(); + Coder[] coders = new Coder[typeArguments.length]; + + for (int i = 0; i < typeArguments.length; i++) { + Type typeArgument = typeArguments[i]; + TypeDescriptor typeDescriptor = TypeDescriptor.of(typeArgument); + try { + coders[i] = coderRegistry.getDefaultCoder(typeDescriptor); + } catch (CannotProvideCoderException e) { + try { + coders[i] = coderRegistry.getDefaultCoder( + typeDescriptor, inputCoder.getEncodedTypeDescriptor(), inputCoder); + } catch (CannotProvideCoderException ignored) { + // Since not all type arguments will have a registered coder we ignore this exception. + } + } + } + + return coders; + } + /** * Perform common validations of the {@link DoFn} against the input {@link PCollection}, for * example ensuring that the window type expected by the {@link DoFn} matches the window type of @@ -557,6 +614,7 @@ public class ParDo { @Override public PCollection expand(PCollection input) { + finishSpecifyingStateSpecs(fn, input.getPipeline().getCoderRegistry(), input.getCoder()); TupleTag mainOutput = new TupleTag<>(); return input.apply(withOutputTags(mainOutput, TupleTagList.empty())).get(mainOutput); } @@ -681,6 +739,10 @@ public class ParDo { public PCollectionTuple expand(PCollection input) { // SplittableDoFn should be forbidden on the runner-side. validateWindowType(input, fn); + + // Use coder registry to determine coders for all StateSpec defined in the fn signature. + finishSpecifyingStateSpecs(fn, input.getPipeline().getCoderRegistry(), input.getCoder()); + PCollectionTuple outputs = PCollectionTuple.ofPrimitiveOutputsInternal( input.getPipeline(), TupleTagList.of(mainOutputTag).and(sideOutputTags.getAll()), http://git-wip-us.apache.org/repos/asf/beam/blob/42e690e8/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateSpec.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateSpec.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateSpec.java index 4fdeefb..6b94c40 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateSpec.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateSpec.java @@ -20,6 +20,7 @@ package org.apache.beam.sdk.util.state; import java.io.Serializable; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.annotations.Experimental.Kind; +import org.apache.beam.sdk.coders.Coder; /** * A specification of a persistent state cell. This includes information necessary to encode the @@ -36,4 +37,18 @@ public interface StateSpec extends Serializable { * Use the {@code binder} to create an instance of {@code StateT} appropriate for this address. */ StateT bind(String id, StateBinder binder); + + /** + * Given {code coders} are inferred from type arguments defined for this class. + * Coders which are already set should take precedence over offered coders. + * @param coders Array of coders indexed by the type arguments order. Entries might be null if + * the coder could not be inferred. + */ + void offerCoders(Coder[] coders); + + /** + * Validates that this {@link StateSpec} has been specified correctly and finalizes it. + * Automatically invoked when the pipeline is built. + */ + void finishSpecifying(); } http://git-wip-us.apache.org/repos/asf/beam/blob/42e690e8/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateSpecs.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateSpecs.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateSpecs.java index 8912993..6a8c80b 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateSpecs.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateSpecs.java @@ -17,7 +17,10 @@ */ package org.apache.beam.sdk.util.state; +import static com.google.common.base.Preconditions.checkArgument; + import java.util.Objects; +import javax.annotation.Nullable; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.annotations.Experimental.Kind; import org.apache.beam.sdk.coders.CannotProvideCoderException; @@ -44,7 +47,13 @@ public class StateSpecs { private StateSpecs() {} /** Create a simple state spec for values of type {@code T}. */ + public static StateSpec> value() { + return new ValueStateSpec<>(null); + } + + /** Create a simple state spec for values of type {@code T}. */ public static StateSpec> value(Coder valueCoder) { + checkArgument(valueCoder != null, "valueCoder should not be null. Consider value() instead"); return new ValueStateSpec<>(valueCoder); } @@ -53,8 +62,21 @@ public class StateSpecs { * {@code InputT}s into a single {@code OutputT}. */ public static + StateSpec> combiningValue( + CombineFn combineFn) { + return new CombiningValueStateSpec(null, combineFn); + } + + /** + * Create a state spec for values that use a {@link CombineFn} to automatically merge multiple + * {@code InputT}s into a single {@code OutputT}. + */ + public static StateSpec> combiningValue( Coder accumCoder, CombineFn combineFn) { + checkArgument(accumCoder != null, + "accumCoder should not be null. " + + "Consider using combiningValue(CombineFn<> combineFn) instead."); return combiningValueInternal(accumCoder, combineFn); } @@ -63,8 +85,21 @@ public class StateSpecs { * multiple {@code InputT}s into a single {@code OutputT}. */ public static + StateSpec> keyedCombiningValue( + KeyedCombineFn combineFn) { + return new KeyedCombiningValueStateSpec(null, combineFn); + } + + /** + * Create a state spec for values that use a {@link KeyedCombineFn} to automatically merge + * multiple {@code InputT}s into a single {@code OutputT}. + */ + public static StateSpec> keyedCombiningValue( Coder accumCoder, KeyedCombineFn combineFn) { + checkArgument(accumCoder != null, + "accumCoder should not be null. " + + "Consider using keyedCombiningValue(KeyedCombineFn<> combineFn) instead."); return keyedCombiningValueInternal(accumCoder, combineFn); } @@ -73,10 +108,23 @@ public class StateSpecs { * merge multiple {@code InputT}s into a single {@code OutputT}. */ public static + StateSpec> + keyedCombiningValueWithContext(KeyedCombineFnWithContext combineFn) { + return new KeyedCombiningValueWithContextStateSpec(null, combineFn); + } + + /** + * Create a state spec for values that use a {@link KeyedCombineFnWithContext} to automatically + * merge multiple {@code InputT}s into a single {@code OutputT}. + */ + public static StateSpec> keyedCombiningValueWithContext( Coder accumCoder, KeyedCombineFnWithContext combineFn) { + checkArgument(accumCoder != null, + "accumCoder should not be null. Consider using " + + "keyedCombiningValueWithContext(KeyedCombineFnWithContext<> combineFn) instead."); return new KeyedCombiningValueWithContextStateSpec( accumCoder, combineFn); } @@ -121,8 +169,23 @@ public class StateSpecs { * Create a state spec that is optimized for adding values frequently, and occasionally retrieving * all the values that have been added. */ + public static StateSpec> bag() { + return bag(null); + } + + /** + * Create a state spec that is optimized for adding values frequently, and occasionally retrieving + * all the values that have been added. + */ public static StateSpec> bag(Coder elemCoder) { - return new BagStateSpec(elemCoder); + return new BagStateSpec<>(elemCoder); + } + + /** + * Create a state spec that supporting for {@link java.util.Set} like access patterns. + */ + public static StateSpec> set() { + return set(null); } /** @@ -135,6 +198,13 @@ public class StateSpecs { /** * Create a state spec that supporting for {@link java.util.Map} like access patterns. */ + public static StateSpec> map() { + return new MapStateSpec<>(null, null); + } + + /** + * Create a state spec that supporting for {@link java.util.Map} like access patterns. + */ public static StateSpec> map(Coder keyCoder, Coder valueCoder) { return new MapStateSpec<>(keyCoder, valueCoder); @@ -174,9 +244,10 @@ public class StateSpecs { */ private static class ValueStateSpec implements StateSpec> { - private final Coder coder; + @Nullable + private Coder coder; - private ValueStateSpec(Coder coder) { + private ValueStateSpec(@Nullable Coder coder) { this.coder = coder; } @@ -185,6 +256,25 @@ public class StateSpecs { return visitor.bindValue(id, this, coder); } + @SuppressWarnings("unchecked") + @Override + public void offerCoders(Coder[] coders) { + if (this.coder == null) { + if (coders[0] != null) { + this.coder = (Coder) coders[0]; + } + } + } + + @Override public void finishSpecifying() { + if (coder == null) { + throw new IllegalStateException("Unable to infer a coder for ValueState and no Coder" + + " was specified. Please set a coder by either invoking" + + " StateSpecs.value(Coder valueCoder) or by registering the coder in the" + + " Pipeline's CoderRegistry."); + } + } + @Override public boolean equals(Object obj) { if (obj == this) { @@ -214,15 +304,32 @@ public class StateSpecs { extends KeyedCombiningValueStateSpec implements StateSpec> { - private final Coder accumCoder; + @Nullable + private Coder accumCoder; private final CombineFn combineFn; private CombiningValueStateSpec( - Coder accumCoder, CombineFn combineFn) { + @Nullable Coder accumCoder, + CombineFn combineFn) { super(accumCoder, combineFn.asKeyedFn()); this.combineFn = combineFn; this.accumCoder = accumCoder; } + + @Override + protected Coder getAccumCoder() { + return accumCoder; + } + + @SuppressWarnings("unchecked") + @Override + public void offerCoders(Coder[] coders) { + if (this.accumCoder == null) { + if (coders[1] != null) { + this.accumCoder = (Coder) coders[1]; + } + } + } } /** @@ -234,11 +341,13 @@ public class StateSpecs { private static class KeyedCombiningValueWithContextStateSpec implements StateSpec> { - private final Coder accumCoder; + @Nullable + private Coder accumCoder; private final KeyedCombineFnWithContext combineFn; protected KeyedCombiningValueWithContextStateSpec( - Coder accumCoder, KeyedCombineFnWithContext combineFn) { + @Nullable Coder accumCoder, + KeyedCombineFnWithContext combineFn) { this.combineFn = combineFn; this.accumCoder = accumCoder; } @@ -249,6 +358,27 @@ public class StateSpecs { return visitor.bindKeyedCombiningValueWithContext(id, this, accumCoder, combineFn); } + @SuppressWarnings("unchecked") + @Override + public void offerCoders(Coder[] coders) { + if (this.accumCoder == null) { + if (coders[2] != null) { + this.accumCoder = (Coder) coders[2]; + } + } + } + + @Override public void finishSpecifying() { + if (accumCoder == null) { + throw new IllegalStateException("Unable to infer a coder for" + + " KeyedCombiningValueWithContextState and no Coder was specified." + + " Please set a coder by either invoking" + + " StateSpecs.keyedCombiningValue(Coder accumCoder," + + " KeyedCombineFn combineFn)" + + " or by registering the coder in the Pipeline's CoderRegistry."); + } + } + @Override public boolean equals(Object obj) { if (obj == this) { @@ -282,19 +412,45 @@ public class StateSpecs { private static class KeyedCombiningValueStateSpec implements StateSpec> { - private final Coder accumCoder; + @Nullable + private Coder accumCoder; private final KeyedCombineFn keyedCombineFn; protected KeyedCombiningValueStateSpec( - Coder accumCoder, KeyedCombineFn keyedCombineFn) { + @Nullable Coder accumCoder, + KeyedCombineFn keyedCombineFn) { this.keyedCombineFn = keyedCombineFn; this.accumCoder = accumCoder; } + protected Coder getAccumCoder() { + return accumCoder; + } + @Override public AccumulatorCombiningState bind( String id, StateBinder visitor) { - return visitor.bindKeyedCombiningValue(id, this, accumCoder, keyedCombineFn); + return visitor.bindKeyedCombiningValue(id, this, getAccumCoder(), keyedCombineFn); + } + + @SuppressWarnings("unchecked") + @Override + public void offerCoders(Coder[] coders) { + if (this.accumCoder == null) { + if (coders[2] != null) { + this.accumCoder = (Coder) coders[2]; + } + } + } + + @Override public void finishSpecifying() { + if (getAccumCoder() == null) { + throw new IllegalStateException("Unable to infer a coder for CombiningState and no" + + " Coder was specified. Please set a coder by either invoking" + + " StateSpecs.combiningValue(Coder accumCoder," + + " CombineFn combineFn)" + + " or by registering the coder in the Pipeline's CoderRegistry."); + } } @Override @@ -330,9 +486,10 @@ public class StateSpecs { */ private static class BagStateSpec implements StateSpec> { - private final Coder elemCoder; + @Nullable + private Coder elemCoder; - private BagStateSpec(Coder elemCoder) { + private BagStateSpec(@Nullable Coder elemCoder) { this.elemCoder = elemCoder; } @@ -341,6 +498,25 @@ public class StateSpecs { return visitor.bindBag(id, this, elemCoder); } + @SuppressWarnings("unchecked") + @Override + public void offerCoders(Coder[] coders) { + if (this.elemCoder == null) { + if (coders[0] != null) { + this.elemCoder = (Coder) coders[0]; + } + } + } + + @Override public void finishSpecifying() { + if (elemCoder == null) { + throw new IllegalStateException("Unable to infer a coder for BagState and no Coder" + + " was specified. Please set a coder by either invoking" + + " StateSpecs.bag(Coder elemCoder) or by registering the coder in the" + + " Pipeline's CoderRegistry."); + } + } + @Override public boolean equals(Object obj) { if (obj == this) { @@ -363,10 +539,12 @@ public class StateSpecs { private static class MapStateSpec implements StateSpec> { - private final Coder keyCoder; - private final Coder valueCoder; + @Nullable + private Coder keyCoder; + @Nullable + private Coder valueCoder; - private MapStateSpec(Coder keyCoder, Coder valueCoder) { + private MapStateSpec(@Nullable Coder keyCoder, @Nullable Coder valueCoder) { this.keyCoder = keyCoder; this.valueCoder = valueCoder; } @@ -376,6 +554,30 @@ public class StateSpecs { return visitor.bindMap(id, this, keyCoder, valueCoder); } + @SuppressWarnings("unchecked") + @Override + public void offerCoders(Coder[] coders) { + if (this.keyCoder == null) { + if (coders[0] != null) { + this.keyCoder = (Coder) coders[0]; + } + } + if (this.valueCoder == null) { + if (coders[1] != null) { + this.valueCoder = (Coder) coders[1]; + } + } + } + + @Override public void finishSpecifying() { + if (keyCoder == null || valueCoder == null) { + throw new IllegalStateException("Unable to infer a coder for MapState and no Coder" + + " was specified. Please set a coder by either invoking" + + " StateSpecs.map(Coder keyCoder, Coder valueCoder) or by registering the" + + " coder in the Pipeline's CoderRegistry."); + } + } + @Override public boolean equals(Object obj) { if (obj == this) { @@ -404,9 +606,10 @@ public class StateSpecs { */ private static class SetStateSpec implements StateSpec> { - private final Coder elemCoder; + @Nullable + private Coder elemCoder; - private SetStateSpec(Coder elemCoder) { + private SetStateSpec(@Nullable Coder elemCoder) { this.elemCoder = elemCoder; } @@ -415,6 +618,25 @@ public class StateSpecs { return visitor.bindSet(id, this, elemCoder); } + @SuppressWarnings("unchecked") + @Override + public void offerCoders(Coder[] coders) { + if (this.elemCoder == null) { + if (coders[0] != null) { + this.elemCoder = (Coder) coders[0]; + } + } + } + + @Override public void finishSpecifying() { + if (elemCoder == null) { + throw new IllegalStateException("Unable to infer a coder for SetState and no Coder" + + " was specified. Please set a coder by either invoking" + + " StateSpecs.set(Coder elemCoder) or by registering the coder in the" + + " Pipeline's CoderRegistry."); + } + } + @Override public boolean equals(Object obj) { if (obj == this) { @@ -461,6 +683,14 @@ public class StateSpecs { } @Override + public void offerCoders(Coder[] coders) { + } + + @Override public void finishSpecifying() { + // Currently an empty implementation as there are no coders to validate. + } + + @Override public boolean equals(Object obj) { if (obj == this) { return true; http://git-wip-us.apache.org/repos/asf/beam/blob/42e690e8/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java index cbbbe5f..4249a77 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java @@ -39,6 +39,7 @@ import static org.junit.Assert.assertThat; import com.fasterxml.jackson.annotation.JsonCreator; import com.google.common.base.MoreObjects; +import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; import com.google.common.collect.Sets; @@ -55,8 +56,12 @@ import java.util.Map; import java.util.Set; import org.apache.beam.sdk.Pipeline.PipelineExecutionException; import org.apache.beam.sdk.coders.AtomicCoder; +import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderException; +import org.apache.beam.sdk.coders.CustomCoder; import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.ListCoder; +import org.apache.beam.sdk.coders.SetCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.io.CountingInput; @@ -1036,6 +1041,71 @@ public class ParDoTest implements Serializable { } } + private static class MyInteger implements Comparable { + private final int value; + + MyInteger(int value) { + this.value = value; + } + + public int getValue() { + return value; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (!(o instanceof MyInteger)) { + return false; + } + + MyInteger myInteger = (MyInteger) o; + + return value == myInteger.value; + + } + + @Override + public int hashCode() { + return value; + } + + @Override + public int compareTo(MyInteger o) { + return Integer.compare(this.getValue(), o.getValue()); + } + + @Override + public String toString() { + return "MyInteger{" + "value=" + value + '}'; + } + } + + private static class MyIntegerCoder extends CustomCoder { + private static final MyIntegerCoder INSTANCE = new MyIntegerCoder(); + + private final VarIntCoder delegate = VarIntCoder.of(); + + public static MyIntegerCoder of() { + return INSTANCE; + } + + @Override + public void encode(MyInteger value, OutputStream outStream, Context context) + throws CoderException, IOException { + delegate.encode(value.getValue(), outStream, context); + } + + @Override + public MyInteger decode(InputStream inStream, Context context) throws CoderException, + IOException { + return new MyInteger(delegate.decode(inStream, context)); + } + } + /** PAssert "matcher" for expected output. */ static class HasExpectedOutput implements SerializableFunction, Void>, Serializable { @@ -1619,6 +1689,132 @@ public class ParDoTest implements Serializable { @Test @Category({ValidatesRunner.class, UsesStatefulParDo.class}) + public void testValueStateCoderInference() { + final String stateId = "foo"; + MyIntegerCoder myIntegerCoder = MyIntegerCoder.of(); + pipeline.getCoderRegistry().registerCoder(MyInteger.class, myIntegerCoder); + + DoFn, MyInteger> fn = + new DoFn, MyInteger>() { + + @StateId(stateId) + private final StateSpec> intState = + StateSpecs.value(); + + @ProcessElement + public void processElement( + ProcessContext c, @StateId(stateId) ValueState state) { + MyInteger currentValue = MoreObjects.firstNonNull(state.read(), new MyInteger(0)); + c.output(currentValue); + state.write(new MyInteger(currentValue.getValue() + 1)); + } + }; + + PCollection output = + pipeline.apply(Create.of(KV.of("hello", 42), KV.of("hello", 97), KV.of("hello", 84))) + .apply(ParDo.of(fn)).setCoder(myIntegerCoder); + + PAssert.that(output).containsInAnyOrder(new MyInteger(0), new MyInteger(1), new MyInteger(2)); + pipeline.run(); + } + + @Test + @Category({ValidatesRunner.class, UsesStatefulParDo.class}) + public void testValueStateCoderInferenceFailure() throws Exception { + final String stateId = "foo"; + MyIntegerCoder myIntegerCoder = MyIntegerCoder.of(); + + DoFn, MyInteger> fn = + new DoFn, MyInteger>() { + + @StateId(stateId) + private final StateSpec> intState = + StateSpecs.value(); + + @ProcessElement + public void processElement( + ProcessContext c, @StateId(stateId) ValueState state) { + MyInteger currentValue = MoreObjects.firstNonNull(state.read(), new MyInteger(0)); + c.output(currentValue); + state.write(new MyInteger(currentValue.getValue() + 1)); + } + }; + + thrown.expect(RuntimeException.class); + thrown.expectMessage("Unable to infer a coder for ValueState and no Coder was specified."); + + pipeline.apply(Create.of(KV.of("hello", 42), KV.of("hello", 97), KV.of("hello", 84))) + .apply(ParDo.of(fn)).setCoder(myIntegerCoder); + + pipeline.run(); + } + + @Test + @Category({ValidatesRunner.class, UsesStatefulParDo.class}) + public void testValueStateCoderInferenceFromInputCoder() { + final String stateId = "foo"; + MyIntegerCoder myIntegerCoder = MyIntegerCoder.of(); + + DoFn, MyInteger> fn = + new DoFn, MyInteger>() { + + @StateId(stateId) + private final StateSpec> intState = + StateSpecs.value(); + + @ProcessElement + public void processElement( + ProcessContext c, @StateId(stateId) ValueState state) { + MyInteger currentValue = MoreObjects.firstNonNull(state.read(), new MyInteger(0)); + c.output(currentValue); + state.write(new MyInteger(currentValue.getValue() + 1)); + } + }; + + pipeline + .apply(Create.of(KV.of("hello", new MyInteger(42)), + KV.of("hello", new MyInteger(97)), KV.of("hello", new MyInteger(84))) + .withCoder(KvCoder.of(StringUtf8Coder.of(), myIntegerCoder))) + .apply(ParDo.of(fn)).setCoder(myIntegerCoder); + + pipeline.run(); + } + + @Test + @Category({ValidatesRunner.class, UsesStatefulParDo.class}) + public void testCoderInferenceOfList() { + final String stateId = "foo"; + MyIntegerCoder myIntegerCoder = MyIntegerCoder.of(); + pipeline.getCoderRegistry().registerCoder(MyInteger.class, myIntegerCoder); + + DoFn, List> fn = + new DoFn, List>() { + + @StateId(stateId) + private final StateSpec>> intState = + StateSpecs.value(); + + @ProcessElement + public void processElement( + ProcessContext c, @StateId(stateId) ValueState> state) { + MyInteger myInteger = new MyInteger(c.element().getValue()); + List currentValue = state.read(); + List newValue = currentValue != null + ? ImmutableList.builder().addAll(currentValue).add(myInteger).build() + : Collections.singletonList(myInteger); + c.output(newValue); + state.write(newValue); + } + }; + + pipeline.apply(Create.of(KV.of("hello", 42), KV.of("hello", 97), KV.of("hello", 84))) + .apply(ParDo.of(fn)).setCoder(ListCoder.of(myIntegerCoder)); + + pipeline.run(); + } + + @Test + @Category({ValidatesRunner.class, UsesStatefulParDo.class}) public void testValueStateFixedWindows() { final String stateId = "foo"; @@ -1801,6 +1997,82 @@ public class ParDoTest implements Serializable { } @Test + @Category({ValidatesRunner.class, UsesStatefulParDo.class}) + public void testBagStateCoderInference() { + final String stateId = "foo"; + Coder myIntegerCoder = MyIntegerCoder.of(); + pipeline.getCoderRegistry().registerCoder(MyInteger.class, myIntegerCoder); + + DoFn, List> fn = + new DoFn, List>() { + + @StateId(stateId) + private final StateSpec> bufferState = + StateSpecs.bag(); + + @ProcessElement + public void processElement( + ProcessContext c, @StateId(stateId) BagState state) { + Iterable currentValue = state.read(); + state.add(new MyInteger(c.element().getValue())); + if (Iterables.size(state.read()) >= 4) { + List sorted = Lists.newArrayList(currentValue); + Collections.sort(sorted); + c.output(sorted); + } + } + }; + + PCollection> output = + pipeline.apply( + Create.of( + KV.of("hello", 97), KV.of("hello", 42), KV.of("hello", 84), KV.of("hello", 12))) + .apply(ParDo.of(fn)).setCoder(ListCoder.of(myIntegerCoder)); + + PAssert.that(output).containsInAnyOrder(Lists.newArrayList(new MyInteger(12), new MyInteger(42), + new MyInteger(84), new MyInteger(97))); + + pipeline.run(); + } + + @Test + @Category({ValidatesRunner.class, UsesStatefulParDo.class}) + public void testBagStateCoderInferenceFailure() throws Exception { + final String stateId = "foo"; + Coder myIntegerCoder = MyIntegerCoder.of(); + + DoFn, List> fn = + new DoFn, List>() { + + @StateId(stateId) + private final StateSpec> bufferState = + StateSpecs.bag(); + + @ProcessElement + public void processElement( + ProcessContext c, @StateId(stateId) BagState state) { + Iterable currentValue = state.read(); + state.add(new MyInteger(c.element().getValue())); + if (Iterables.size(state.read()) >= 4) { + List sorted = Lists.newArrayList(currentValue); + Collections.sort(sorted); + c.output(sorted); + } + } + }; + + thrown.expect(RuntimeException.class); + thrown.expectMessage("Unable to infer a coder for BagState and no Coder was specified."); + + pipeline.apply( + Create.of( + KV.of("hello", 97), KV.of("hello", 42), KV.of("hello", 84), KV.of("hello", 12))) + .apply(ParDo.of(fn)).setCoder(ListCoder.of(myIntegerCoder)); + + pipeline.run(); + } + + @Test @Category({ValidatesRunner.class, UsesStatefulParDo.class, UsesSetState.class}) public void testSetState() { final String stateId = "foo"; @@ -1843,6 +2115,93 @@ public class ParDoTest implements Serializable { } @Test + @Category({ValidatesRunner.class, UsesStatefulParDo.class, UsesSetState.class}) + public void testSetStateCoderInference() { + final String stateId = "foo"; + final String countStateId = "count"; + Coder myIntegerCoder = MyIntegerCoder.of(); + pipeline.getCoderRegistry().registerCoder(MyInteger.class, myIntegerCoder); + + DoFn, Set> fn = + new DoFn, Set>() { + + @StateId(stateId) + private final StateSpec> setState = StateSpecs.set(); + + @StateId(countStateId) + private final StateSpec> + countState = StateSpecs.combiningValueFromInputInternal(VarIntCoder.of(), + Sum.ofIntegers()); + + @ProcessElement + public void processElement( + ProcessContext c, + @StateId(stateId) SetState state, + @StateId(countStateId) AccumulatorCombiningState count) { + state.add(new MyInteger(c.element().getValue())); + count.add(1); + if (count.read() >= 4) { + Set set = Sets.newHashSet(state.read()); + c.output(set); + } + } + }; + + PCollection> output = + pipeline.apply( + Create.of( + KV.of("hello", 97), KV.of("hello", 42), KV.of("hello", 42), KV.of("hello", 12))) + .apply(ParDo.of(fn)).setCoder(SetCoder.of(myIntegerCoder)); + + PAssert.that(output).containsInAnyOrder( + Sets.newHashSet(new MyInteger(97), new MyInteger(42), new MyInteger(12))); + pipeline.run(); + } + + @Test + @Category({ValidatesRunner.class, UsesStatefulParDo.class, UsesSetState.class}) + public void testSetStateCoderInferenceFailure() throws Exception { + final String stateId = "foo"; + final String countStateId = "count"; + Coder myIntegerCoder = MyIntegerCoder.of(); + + DoFn, Set> fn = + new DoFn, Set>() { + + @StateId(stateId) + private final StateSpec> setState = StateSpecs.set(); + + @StateId(countStateId) + private final StateSpec> + countState = StateSpecs.combiningValueFromInputInternal(VarIntCoder.of(), + Sum.ofIntegers()); + + @ProcessElement + public void processElement( + ProcessContext c, + @StateId(stateId) SetState state, + @StateId(countStateId) AccumulatorCombiningState count) { + state.add(new MyInteger(c.element().getValue())); + count.add(1); + if (count.read() >= 4) { + Set set = Sets.newHashSet(state.read()); + c.output(set); + } + } + }; + + thrown.expect(RuntimeException.class); + thrown.expectMessage("Unable to infer a coder for SetState and no Coder was specified."); + + pipeline.apply( + Create.of( + KV.of("hello", 97), KV.of("hello", 42), KV.of("hello", 42), KV.of("hello", 12))) + .apply(ParDo.of(fn)).setCoder(SetCoder.of(myIntegerCoder)); + + pipeline.run(); + } + + @Test @Category({ValidatesRunner.class, UsesStatefulParDo.class, UsesMapState.class}) public void testMapState() { final String stateId = "foo"; @@ -1888,6 +2247,99 @@ public class ParDoTest implements Serializable { } @Test + @Category({ValidatesRunner.class, UsesStatefulParDo.class, UsesMapState.class}) + public void testMapStateCoderInference() { + final String stateId = "foo"; + final String countStateId = "count"; + Coder myIntegerCoder = MyIntegerCoder.of(); + pipeline.getCoderRegistry().registerCoder(MyInteger.class, myIntegerCoder); + + DoFn>, KV> fn = + new DoFn>, KV>() { + + @StateId(stateId) + private final StateSpec> mapState = StateSpecs.map(); + @StateId(countStateId) + private final StateSpec> + countState = StateSpecs.combiningValueFromInputInternal(VarIntCoder.of(), + Sum.ofIntegers()); + + @ProcessElement + public void processElement( + ProcessContext c, @StateId(stateId) MapState state, + @StateId(countStateId) AccumulatorCombiningState + count) { + KV value = c.element().getValue(); + state.put(value.getKey(), new MyInteger(value.getValue())); + count.add(1); + if (count.read() >= 4) { + Iterable> iterate = state.iterate(); + for (Map.Entry entry : iterate) { + c.output(KV.of(entry.getKey(), entry.getValue())); + } + } + } + }; + + PCollection> output = + pipeline.apply( + Create.of( + KV.of("hello", KV.of("a", 97)), KV.of("hello", KV.of("b", 42)), + KV.of("hello", KV.of("b", 42)), KV.of("hello", KV.of("c", 12)))) + .apply(ParDo.of(fn)).setCoder(KvCoder.of(StringUtf8Coder.of(), myIntegerCoder)); + + PAssert.that(output).containsInAnyOrder(KV.of("a", new MyInteger(97)), + KV.of("b", new MyInteger(42)), KV.of("c", new MyInteger(12))); + pipeline.run(); + } + + @Test + @Category({ValidatesRunner.class, UsesStatefulParDo.class, UsesMapState.class}) + public void testMapStateCoderInferenceFailure() throws Exception { + final String stateId = "foo"; + final String countStateId = "count"; + Coder myIntegerCoder = MyIntegerCoder.of(); + + DoFn>, KV> fn = + new DoFn>, KV>() { + + @StateId(stateId) + private final StateSpec> mapState = StateSpecs.map(); + @StateId(countStateId) + private final StateSpec> + countState = StateSpecs.combiningValueFromInputInternal(VarIntCoder.of(), + Sum.ofIntegers()); + + @ProcessElement + public void processElement( + ProcessContext c, @StateId(stateId) MapState state, + @StateId(countStateId) AccumulatorCombiningState + count) { + KV value = c.element().getValue(); + state.put(value.getKey(), new MyInteger(value.getValue())); + count.add(1); + if (count.read() >= 4) { + Iterable> iterate = state.iterate(); + for (Map.Entry entry : iterate) { + c.output(KV.of(entry.getKey(), entry.getValue())); + } + } + } + }; + + thrown.expect(RuntimeException.class); + thrown.expectMessage("Unable to infer a coder for MapState and no Coder was specified."); + + pipeline.apply( + Create.of( + KV.of("hello", KV.of("a", 97)), KV.of("hello", KV.of("b", 42)), + KV.of("hello", KV.of("b", 42)), KV.of("hello", KV.of("c", 12)))) + .apply(ParDo.of(fn)).setCoder(KvCoder.of(StringUtf8Coder.of(), myIntegerCoder)); + + pipeline.run(); + } + + @Test @Category({ValidatesRunner.class, UsesStatefulParDo.class}) public void testCombiningState() { final String stateId = "foo"; @@ -1928,6 +2380,132 @@ public class ParDoTest implements Serializable { @Test @Category({ValidatesRunner.class, UsesStatefulParDo.class}) + public void testCombiningStateCoderInference() { + pipeline.getCoderRegistry().registerCoder(MyInteger.class, MyIntegerCoder.of()); + + final String stateId = "foo"; + + DoFn, String> fn = + new DoFn, String>() { + private static final int EXPECTED_SUM = 16; + + @StateId(stateId) + private final StateSpec< + Object, AccumulatorCombiningState> + combiningState = + StateSpecs.combiningValue(new Combine.CombineFn() { + @Override + public MyInteger createAccumulator() { + return new MyInteger(0); + } + + @Override + public MyInteger addInput(MyInteger accumulator, Integer input) { + return new MyInteger(accumulator.getValue() + input); + } + + @Override + public MyInteger mergeAccumulators(Iterable accumulators) { + int newValue = 0; + for (MyInteger myInteger : accumulators) { + newValue += myInteger.getValue(); + } + return new MyInteger(newValue); + } + + @Override + public Integer extractOutput(MyInteger accumulator) { + return accumulator.getValue(); + } + }); + + @ProcessElement + public void processElement( + ProcessContext c, + @StateId(stateId) + AccumulatorCombiningState state) { + state.add(c.element().getValue()); + Integer currentValue = state.read(); + if (currentValue == EXPECTED_SUM) { + c.output("right on"); + } + } + }; + + PCollection output = + pipeline + .apply(Create.of(KV.of("hello", 3), KV.of("hello", 6), KV.of("hello", 7))) + .apply(ParDo.of(fn)); + + // There should only be one moment at which the average is exactly 16 + PAssert.that(output).containsInAnyOrder("right on"); + pipeline.run(); + } + + @Test + @Category({ValidatesRunner.class, UsesStatefulParDo.class}) + public void testCombiningStateCoderInferenceFailure() throws Exception { + final String stateId = "foo"; + + DoFn, String> fn = + new DoFn, String>() { + private static final int EXPECTED_SUM = 16; + + @StateId(stateId) + private final StateSpec< + Object, AccumulatorCombiningState> + combiningState = + StateSpecs.combiningValue(new Combine.CombineFn() { + @Override + public MyInteger createAccumulator() { + return new MyInteger(0); + } + + @Override + public MyInteger addInput(MyInteger accumulator, Integer input) { + return new MyInteger(accumulator.getValue() + input); + } + + @Override + public MyInteger mergeAccumulators(Iterable accumulators) { + int newValue = 0; + for (MyInteger myInteger : accumulators) { + newValue += myInteger.getValue(); + } + return new MyInteger(newValue); + } + + @Override + public Integer extractOutput(MyInteger accumulator) { + return accumulator.getValue(); + } + }); + + @ProcessElement + public void processElement( + ProcessContext c, + @StateId(stateId) + AccumulatorCombiningState state) { + state.add(c.element().getValue()); + Integer currentValue = state.read(); + if (currentValue == EXPECTED_SUM) { + c.output("right on"); + } + } + }; + + thrown.expect(RuntimeException.class); + thrown.expectMessage("Unable to infer a coder for CombiningState and no Coder was specified."); + + pipeline + .apply(Create.of(KV.of("hello", 3), KV.of("hello", 6), KV.of("hello", 7))) + .apply(ParDo.of(fn)); + + pipeline.run(); + } + + @Test + @Category({ValidatesRunner.class, UsesStatefulParDo.class}) public void testBagStateSideInput() { final PCollectionView> listView =