beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From bchamb...@apache.org
Subject [1/2] incubator-beam git commit: [BEAM-96] Add composed `CombineFn` builders in `CombineFns`
Date Thu, 17 Mar 2016 21:17:08 GMT
Repository: incubator-beam
Updated Branches:
  refs/heads/master ac63fd6d4 -> c30326007


[BEAM-96] Add composed `CombineFn` builders in `CombineFns`

* `compose()` or `composeKeyed()` are used to start composition
* `with()` is used to add an input-transformation, a `CombineFn`
  and an output `TupleTag`.
* A non-`CombineFn` initial builder is used to ensure that every
  composition includes at least one item
* Duplicate output tags are not allowed in the same composition


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/23b43780
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/23b43780
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/23b43780

Branch: refs/heads/master
Commit: 23b437802546f32a167b38f8d0bc7a566abde224
Parents: ac63fd6
Author: Pei He <peihe@google.com>
Authored: Fri Mar 4 13:54:34 2016 -0800
Committer: bchambers <bchambers@google.com>
Committed: Thu Mar 17 13:54:40 2016 -0700

----------------------------------------------------------------------
 .../dataflow/sdk/transforms/CombineFns.java     | 1100 ++++++++++++++++++
 .../dataflow/sdk/transforms/CombineFnsTest.java |  413 +++++++
 2 files changed, 1513 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/23b43780/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/CombineFns.java
----------------------------------------------------------------------
diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/CombineFns.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/CombineFns.java
new file mode 100644
index 0000000..656c010
--- /dev/null
+++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/CombineFns.java
@@ -0,0 +1,1100 @@
+/*
+ * Copyright (C) 2016 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package com.google.cloud.dataflow.sdk.transforms;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException;
+import com.google.cloud.dataflow.sdk.coders.Coder;
+import com.google.cloud.dataflow.sdk.coders.CoderException;
+import com.google.cloud.dataflow.sdk.coders.CoderRegistry;
+import com.google.cloud.dataflow.sdk.coders.StandardCoder;
+import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn;
+import com.google.cloud.dataflow.sdk.transforms.Combine.KeyedCombineFn;
+import com.google.cloud.dataflow.sdk.transforms.CombineFnBase.GlobalCombineFn;
+import com.google.cloud.dataflow.sdk.transforms.CombineFnBase.PerKeyCombineFn;
+import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.CombineFnWithContext;
+import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.Context;
+import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext;
+import com.google.cloud.dataflow.sdk.util.PropertyNames;
+import com.google.cloud.dataflow.sdk.values.TupleTag;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.io.Serializable;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+
+/**
+ * Static utility methods that create combine function instances.
+ */
+public class CombineFns {
+
+  /**
+   * Returns a {@link ComposeKeyedCombineFnBuilder} to construct a composed
+   * {@link PerKeyCombineFn}.
+   *
+   * <p>The same {@link TupleTag} cannot be used in a composition multiple times.
+   *
+   * <p>Example:
+   * <pre>{@code
+   * PCollection<KV<K, Integer>> latencies = ...;
+   *
+   * TupleTag<Integer> maxLatencyTag = new TupleTag<Integer>();
+   * TupleTag<Double> meanLatencyTag = new TupleTag<Double>();
+   *
+   * SimpleFunction<Integer, Integer> identityFn =
+   *     new SimpleFunction<Integer, Integer>() {
+   *       @Override
+   *       public Integer apply(Integer input) {
+   *         return input;
+   *       }};
+   * PCollection<KV<K, CoCombineResult>> maxAndMean = latencies.apply(
+   *     Combine.perKey(
+   *         CombineFns.composeKeyed()
+   *             .with(identityFn, new MaxIntegerFn(), maxLatencyTag)
+   *             .with(identityFn, new MeanFn<Integer>(), meanLatencyTag)));
+   *
+   * PCollection<T> finalResultCollection = maxAndMean
+   *     .apply(ParDo.of(
+   *         new DoFn<KV<K, CoCombineResult>, T>() {
+   *           @Override
+   *           public void processElement(ProcessContext c) throws Exception {
+   *             KV<K, CoCombineResult> e = c.element();
+   *             Integer maxLatency = e.getValue().get(maxLatencyTag);
+   *             Double meanLatency = e.getValue().get(meanLatencyTag);
+   *             .... Do Something ....
+   *             c.output(...some T...);
+   *           }
+   *         }));
+   * }</pre>
+   *
+   * @return a new builder; it holds no state, and each of its {@code with} methods starts a
+   *     fresh composition
+   */
+  public static ComposeKeyedCombineFnBuilder composeKeyed() {
+    return new ComposeKeyedCombineFnBuilder();
+  }
+
+  /**
+   * Returns a {@link ComposeCombineFnBuilder} to construct a composed
+   * {@link GlobalCombineFn}.
+   *
+   * <p>The same {@link TupleTag} cannot be used in a composition multiple times.
+   *
+   * <p>Example:
+   * <pre>{@code
+   * PCollection<Integer> globalLatencies = ...;
+   *
+   * TupleTag<Integer> maxLatencyTag = new TupleTag<Integer>();
+   * TupleTag<Double> meanLatencyTag = new TupleTag<Double>();
+   *
+   * SimpleFunction<Integer, Integer> identityFn =
+   *     new SimpleFunction<Integer, Integer>() {
+   *       @Override
+   *       public Integer apply(Integer input) {
+   *         return input;
+   *       }};
+   * PCollection<CoCombineResult> maxAndMean = globalLatencies.apply(
+   *     Combine.globally(
+   *         CombineFns.compose()
+   *             .with(identityFn, new MaxIntegerFn(), maxLatencyTag)
+   *             .with(identityFn, new MeanFn<Integer>(), meanLatencyTag)));
+   *
+   * PCollection<T> finalResultCollection = maxAndMean
+   *     .apply(ParDo.of(
+   *         new DoFn<CoCombineResult, T>() {
+   *           @Override
+   *           public void processElement(ProcessContext c) throws Exception {
+   *             CoCombineResult e = c.element();
+   *             Integer maxLatency = e.get(maxLatencyTag);
+   *             Double meanLatency = e.get(meanLatencyTag);
+   *             .... Do Something ....
+   *             c.output(...some T...);
+   *           }
+   *         }));
+   * }</pre>
+   *
+   * @return a new builder; it holds no state, and each of its {@code with} methods starts a
+   *     fresh composition
+   */
+  public static ComposeCombineFnBuilder compose() {
+    return new ComposeCombineFnBuilder();
+  }
+
+  /////////////////////////////////////////////////////////////////////////////
+
+  /**
+   * Entry point for building a composed {@link PerKeyCombineFn}.
+   *
+   * <p>The builder itself carries no state. Every {@code with} overload begins a brand-new
+   * composition that contains exactly one combine function; more functions are appended by
+   * calling {@code with} on the composed function that is returned.
+   */
+  public static class ComposeKeyedCombineFnBuilder {
+    /**
+     * Starts a composition from a single {@link KeyedCombineFn}.
+     *
+     * <p>The resulting {@link ComposedKeyedCombineFn} runs {@code extractInputFn} on every
+     * {@code DataT} element, combines the extracted values with {@code keyedCombineFn}, and
+     * publishes the combined value under {@code outputTag} in a {@link CoCombineResult}.
+     * Further {@link PerKeyCombineFn PerKeyCombineFns} may be chained onto it.
+     */
+    public <K, DataT, InputT, OutputT> ComposedKeyedCombineFn<DataT, K> with(
+        SimpleFunction<DataT, InputT> extractInputFn,
+        KeyedCombineFn<K, InputT, ?, OutputT> keyedCombineFn,
+        TupleTag<OutputT> outputTag) {
+      ComposedKeyedCombineFn<DataT, K> seed = new ComposedKeyedCombineFn<>();
+      return seed.with(extractInputFn, keyedCombineFn, outputTag);
+    }
+
+    /**
+     * Starts a composition from a single {@link KeyedCombineFnWithContext}.
+     *
+     * <p>The resulting {@link ComposedKeyedCombineFnWithContext} runs {@code extractInputFn} on
+     * every {@code DataT} element, combines the extracted values with
+     * {@code keyedCombineFnWithContext}, and publishes the combined value under
+     * {@code outputTag} in a {@link CoCombineResult}. Further
+     * {@link PerKeyCombineFn PerKeyCombineFns} may be chained onto it.
+     */
+    public <K, DataT, InputT, OutputT> ComposedKeyedCombineFnWithContext<DataT, K> with(
+        SimpleFunction<DataT, InputT> extractInputFn,
+        KeyedCombineFnWithContext<K, InputT, ?, OutputT> keyedCombineFnWithContext,
+        TupleTag<OutputT> outputTag) {
+      ComposedKeyedCombineFnWithContext<DataT, K> seed =
+          new ComposedKeyedCombineFnWithContext<>();
+      return seed.with(extractInputFn, keyedCombineFnWithContext, outputTag);
+    }
+
+    /**
+     * Starts a composition from a single unkeyed {@link CombineFn}, which is adapted into a
+     * keyed function that ignores the key.
+     */
+    public <K, DataT, InputT, OutputT> ComposedKeyedCombineFn<DataT, K> with(
+        SimpleFunction<DataT, InputT> extractInputFn,
+        CombineFn<InputT, ?, OutputT> combineFn,
+        TupleTag<OutputT> outputTag) {
+      KeyedCombineFn<K, InputT, ?, OutputT> keyedView = combineFn.<K>asKeyedFn();
+      return with(extractInputFn, keyedView, outputTag);
+    }
+
+    /**
+     * Starts a composition from a single unkeyed {@link CombineFnWithContext}, which is adapted
+     * into a keyed function that ignores the key.
+     */
+    public <K, DataT, InputT, OutputT> ComposedKeyedCombineFnWithContext<DataT, K> with(
+        SimpleFunction<DataT, InputT> extractInputFn,
+        CombineFnWithContext<InputT, ?, OutputT> combineFnWithContext,
+        TupleTag<OutputT> outputTag) {
+      KeyedCombineFnWithContext<K, InputT, ?, OutputT> keyedView =
+          combineFnWithContext.<K>asKeyedFn();
+      return with(extractInputFn, keyedView, outputTag);
+    }
+  }
+
+  /**
+   * Entry point for building a composed {@link GlobalCombineFn}.
+   *
+   * <p>The builder itself carries no state. Every {@code with} overload begins a brand-new
+   * composition that contains exactly one combine function; more functions are appended by
+   * calling {@code with} on the composed function that is returned.
+   */
+  public static class ComposeCombineFnBuilder {
+    /**
+     * Starts a composition from a single {@link CombineFn}.
+     *
+     * <p>The resulting {@link ComposedCombineFn} runs {@code extractInputFn} on every
+     * {@code DataT} element, combines the extracted values with {@code combineFn}, and
+     * publishes the combined value under {@code outputTag} in a {@link CoCombineResult}.
+     * Further {@link GlobalCombineFn GlobalCombineFns} may be chained onto it.
+     */
+    public <DataT, InputT, OutputT> ComposedCombineFn<DataT> with(
+        SimpleFunction<DataT, InputT> extractInputFn,
+        CombineFn<InputT, ?, OutputT> combineFn,
+        TupleTag<OutputT> outputTag) {
+      ComposedCombineFn<DataT> seed = new ComposedCombineFn<>();
+      return seed.with(extractInputFn, combineFn, outputTag);
+    }
+
+    /**
+     * Starts a composition from a single {@link CombineFnWithContext}.
+     *
+     * <p>The resulting {@link ComposedCombineFnWithContext} runs {@code extractInputFn} on every
+     * {@code DataT} element, combines the extracted values with {@code combineFnWithContext},
+     * and publishes the combined value under {@code outputTag} in a {@link CoCombineResult}.
+     * Further {@link GlobalCombineFn GlobalCombineFns} may be chained onto it.
+     */
+    public <DataT, InputT, OutputT> ComposedCombineFnWithContext<DataT> with(
+        SimpleFunction<DataT, InputT> extractInputFn,
+        CombineFnWithContext<InputT, ?, OutputT> combineFnWithContext,
+        TupleTag<OutputT> outputTag) {
+      ComposedCombineFnWithContext<DataT> seed = new ComposedCombineFnWithContext<>();
+      return seed.with(extractInputFn, combineFnWithContext, outputTag);
+    }
+  }
+
+  /////////////////////////////////////////////////////////////////////////////
+
+  /**
+   * A tuple of outputs produced by a composed combine function.
+   *
+   * <p>Each output is stored under the {@link TupleTag} it was registered with. Instances are
+   * immutable and compare equal when they hold equal values under equal tags.
+   *
+   * <p>See {@link #compose()} or {@link #composeKeyed()} for details.
+   */
+  public static class CoCombineResult implements Serializable {
+
+    /**
+     * Sentinel stored in place of {@code null} outputs, because {@link ImmutableMap} does not
+     * accept {@code null} values.
+     */
+    private enum NullValue {
+      INSTANCE;
+
+      @Override
+      public String toString() {
+        return "null";
+      }
+    }
+
+    private final Map<TupleTag<?>, Object> valuesMap;
+
+    /**
+     * The constructor of {@link CoCombineResult}.
+     *
+     * <p>{@code null} values are permitted. They are stored as an internal placeholder, so the
+     * tag stays present and {@link #get} returns {@code null} for it.
+     *
+     * @throws NullPointerException if any key in {@code valuesMap} is null
+     */
+    CoCombineResult(Map<TupleTag<?>, Object> valuesMap) {
+      ImmutableMap.Builder<TupleTag<?>, Object> builder = ImmutableMap.builder();
+      for (Entry<TupleTag<?>, Object> entry : valuesMap.entrySet()) {
+        Object value = entry.getValue();
+        builder.put(entry.getKey(), value == null ? NullValue.INSTANCE : value);
+      }
+      this.valuesMap = builder.build();
+    }
+
+    /**
+     * Returns the value represented by the given {@link TupleTag}, which may be {@code null}
+     * if the corresponding combine function produced {@code null}.
+     *
+     * @throws IllegalArgumentException if {@code tag} is not part of this result
+     */
+    @SuppressWarnings("unchecked")
+    public <V> V get(TupleTag<V> tag) {
+      // Use the %s template so the message is only formatted when the check fails.
+      checkArgument(
+          valuesMap.containsKey(tag), "TupleTag %s is not in the CoCombineResult", tag);
+      Object value = valuesMap.get(tag);
+      return value == NullValue.INSTANCE ? null : (V) value;
+    }
+
+    @Override
+    public boolean equals(Object other) {
+      if (this == other) {
+        return true;
+      }
+      if (!(other instanceof CoCombineResult)) {
+        return false;
+      }
+      return valuesMap.equals(((CoCombineResult) other).valuesMap);
+    }
+
+    @Override
+    public int hashCode() {
+      return valuesMap.hashCode();
+    }
+
+    @Override
+    public String toString() {
+      return "CoCombineResult" + valuesMap;
+    }
+  }
+
+  /////////////////////////////////////////////////////////////////////////////
+
+  /**
+   * A composed {@link CombineFn} that applies multiple {@link CombineFn CombineFns}.
+   *
+   * <p>For each {@link CombineFn} it extracts inputs from {@code DataT} with
+   * the {@code extractInputFn} and combines them,
+   * and then it outputs each combined value with a {@link TupleTag} to a
+   * {@link CoCombineResult}.
+   *
+   * <p>The accumulator is an {@code Object[]} whose i-th slot holds the accumulator of the
+   * i-th composed {@link CombineFn}. A composition is never modified in place: each
+   * {@code with} call returns a new instance.
+   *
+   * @param <DataT> type of the input elements from which each component's input is extracted
+   */
+  public static class ComposedCombineFn<DataT> extends CombineFn<DataT, Object[], CoCombineResult> {
+
+    // Parallel lists: entry i of each list describes the i-th composed function.
+    private final List<CombineFn<Object, Object, Object>> combineFns;
+    private final List<SerializableFunction<DataT, Object>> extractInputFns;
+    private final List<TupleTag<?>> outputTags;
+    private final int combineFnCount;
+
+    private ComposedCombineFn() {
+      this.extractInputFns = ImmutableList.of();
+      this.combineFns = ImmutableList.of();
+      this.outputTags = ImmutableList.of();
+      this.combineFnCount = 0;
+    }
+
+    private ComposedCombineFn(
+        ImmutableList<SerializableFunction<DataT, ?>> extractInputFns,
+        ImmutableList<CombineFn<?, ?, ?>> combineFns,
+        ImmutableList<TupleTag<?>> outputTags) {
+      // The components are heterogeneous; their element types are erased to Object so values
+      // can flow from extractor to combine function without a cast at every call site.
+      @SuppressWarnings({"unchecked", "rawtypes"})
+      List<SerializableFunction<DataT, Object>> castedExtractInputFns = (List) extractInputFns;
+      this.extractInputFns = castedExtractInputFns;
+
+      @SuppressWarnings({"unchecked", "rawtypes"})
+      List<CombineFn<Object, Object, Object>> castedCombineFns = (List) combineFns;
+      this.combineFns = castedCombineFns;
+
+      this.outputTags = outputTags;
+      this.combineFnCount = this.combineFns.size();
+    }
+
+    /**
+     * Returns a {@link ComposedCombineFn} with an additional {@link CombineFn}.
+     * This instance is left unchanged.
+     *
+     * @param extractInputFn extracts the new function's input from each {@code DataT}
+     * @param combineFn the {@link CombineFn} to add
+     * @param outputTag the tag under which its output appears in the {@link CoCombineResult};
+     *     must not already be used in this composition
+     */
+    public <InputT, OutputT> ComposedCombineFn<DataT> with(
+        SimpleFunction<DataT, InputT> extractInputFn,
+        CombineFn<InputT, ?, OutputT> combineFn,
+        TupleTag<OutputT> outputTag) {
+      checkUniqueness(outputTags, outputTag);
+      return new ComposedCombineFn<>(
+          ImmutableList.<SerializableFunction<DataT, ?>>builder()
+              .addAll(extractInputFns)
+              .add(extractInputFn)
+              .build(),
+          ImmutableList.<CombineFn<?, ?, ?>>builder()
+              .addAll(combineFns)
+              .add(combineFn)
+              .build(),
+          ImmutableList.<TupleTag<?>>builder()
+              .addAll(outputTags)
+              .add(outputTag)
+              .build());
+    }
+
+    /**
+     * Returns a {@link ComposedCombineFnWithContext} with an additional
+     * {@link CombineFnWithContext}. The {@link CombineFn CombineFns} already in this
+     * composition are carried over, converted to their {@link CombineFnWithContext} form.
+     * This instance is left unchanged.
+     *
+     * @param extractInputFn extracts the new function's input from each {@code DataT}
+     * @param combineFnWithContext the {@link CombineFnWithContext} to add
+     * @param outputTag the tag under which its output appears in the {@link CoCombineResult};
+     *     must not already be used in this composition
+     */
+    public <InputT, OutputT> ComposedCombineFnWithContext<DataT> with(
+        SimpleFunction<DataT, InputT> extractInputFn,
+        CombineFnWithContext<InputT, ?, OutputT> combineFnWithContext,
+        TupleTag<OutputT> outputTag) {
+      checkUniqueness(outputTags, outputTag);
+      // Convert the already-composed plain CombineFns so every component shares one type.
+      List<CombineFnWithContext<Object, Object, Object>> fnsWithContext = Lists.newArrayList();
+      for (CombineFn<Object, Object, Object> fn : combineFns) {
+        fnsWithContext.add(toFnWithContext(fn));
+      }
+      return new ComposedCombineFnWithContext<>(
+          ImmutableList.<SerializableFunction<DataT, ?>>builder()
+              .addAll(extractInputFns)
+              .add(extractInputFn)
+              .build(),
+          ImmutableList.<CombineFnWithContext<?, ?, ?>>builder()
+              .addAll(fnsWithContext)
+              .add(combineFnWithContext)
+              .build(),
+          ImmutableList.<TupleTag<?>>builder()
+              .addAll(outputTags)
+              .add(outputTag)
+              .build());
+    }
+
+    @Override
+    public Object[] createAccumulator() {
+      Object[] accumsArray = new Object[combineFnCount];
+      for (int i = 0; i < combineFnCount; ++i) {
+        accumsArray[i] = combineFns.get(i).createAccumulator();
+      }
+      return accumsArray;
+    }
+
+    @Override
+    public Object[] addInput(Object[] accumulator, DataT value) {
+      for (int i = 0; i < combineFnCount; ++i) {
+        Object input = extractInputFns.get(i).apply(value);
+        accumulator[i] = combineFns.get(i).addInput(accumulator[i], input);
+      }
+      return accumulator;
+    }
+
+    @Override
+    public Object[] mergeAccumulators(Iterable<Object[]> accumulators) {
+      Iterator<Object[]> iter = accumulators.iterator();
+      if (!iter.hasNext()) {
+        return createAccumulator();
+      } else {
+        // Reuses the first accumulator, and overwrites its values.
+        // It is safe because {@code accum[i]} only depends on
+        // the i-th component of each accumulator.
+        Object[] accum = iter.next();
+        for (int i = 0; i < combineFnCount; ++i) {
+          accum[i] = combineFns.get(i).mergeAccumulators(new ProjectionIterable(accumulators, i));
+        }
+        return accum;
+      }
+    }
+
+    @Override
+    public CoCombineResult extractOutput(Object[] accumulator) {
+      Map<TupleTag<?>, Object> valuesMap = Maps.newHashMap();
+      for (int i = 0; i < combineFnCount; ++i) {
+        valuesMap.put(
+            outputTags.get(i),
+            combineFns.get(i).extractOutput(accumulator[i]));
+      }
+      return new CoCombineResult(valuesMap);
+    }
+
+    @Override
+    public Object[] compact(Object[] accumulator) {
+      for (int i = 0; i < combineFnCount; ++i) {
+        accumulator[i] = combineFns.get(i).compact(accumulator[i]);
+      }
+      return accumulator;
+    }
+
+    /**
+     * Builds the accumulator coder component-wise: the input coder of each composed function
+     * is inferred from its {@code extractInputFn} and {@code dataCoder}.
+     */
+    @Override
+    public Coder<Object[]> getAccumulatorCoder(CoderRegistry registry, Coder<DataT> dataCoder)
+        throws CannotProvideCoderException {
+      List<Coder<Object>> coders = Lists.newArrayList();
+      for (int i = 0; i < combineFnCount; ++i) {
+        Coder<Object> inputCoder =
+            registry.getDefaultOutputCoder(extractInputFns.get(i), dataCoder);
+        coders.add(combineFns.get(i).getAccumulatorCoder(registry, inputCoder));
+      }
+      return new ComposedAccumulatorCoder(coders);
+    }
+  }
+
+  /**
+   * A composed {@link CombineFnWithContext} that runs several
+   * {@link CombineFnWithContext CombineFnWithContexts} side by side.
+   *
+   * <p>Every {@code DataT} element is passed through each registered {@code extractInputFn};
+   * the extracted value goes to the matching combine function, and each final result is
+   * published under its {@link TupleTag} in a {@link CoCombineResult}. The accumulator is an
+   * {@code Object[]} holding one sub-accumulator per combine function.
+   */
+  public static class ComposedCombineFnWithContext<DataT>
+      extends CombineFnWithContext<DataT, Object[], CoCombineResult> {
+
+    private final List<SerializableFunction<DataT, Object>> extractInputFns;
+    private final List<CombineFnWithContext<Object, Object, Object>> combineFnWithContexts;
+    private final List<TupleTag<?>> outputTags;
+    private final int combineFnCount;
+
+    /** Creates an empty composition containing no combine functions. */
+    private ComposedCombineFnWithContext() {
+      this(
+          ImmutableList.<SerializableFunction<DataT, ?>>of(),
+          ImmutableList.<CombineFnWithContext<?, ?, ?>>of(),
+          ImmutableList.<TupleTag<?>>of());
+    }
+
+    private ComposedCombineFnWithContext(
+        ImmutableList<SerializableFunction<DataT, ?>> extractInputFns,
+        ImmutableList<CombineFnWithContext<?, ?, ?>> combineFnWithContexts,
+        ImmutableList<TupleTag<?>> outputTags) {
+      // Components are heterogeneous, so their element types are erased to Object here.
+      @SuppressWarnings({"unchecked", "rawtypes"})
+      List<SerializableFunction<DataT, Object>> erasedExtractors = (List) extractInputFns;
+      @SuppressWarnings({"unchecked", "rawtypes"})
+      List<CombineFnWithContext<Object, Object, Object>> erasedFns = (List) combineFnWithContexts;
+
+      this.extractInputFns = erasedExtractors;
+      this.combineFnWithContexts = erasedFns;
+      this.outputTags = outputTags;
+      this.combineFnCount = erasedFns.size();
+    }
+
+    /**
+     * Returns a new {@link ComposedCombineFnWithContext} that additionally applies
+     * {@code globalCombineFn}, converted to its {@link CombineFnWithContext} form, to the
+     * values produced by {@code extractInputFn}, publishing the result under
+     * {@code outputTag}. This instance is not modified.
+     */
+    public <InputT, OutputT> ComposedCombineFnWithContext<DataT> with(
+        SimpleFunction<DataT, InputT> extractInputFn,
+        GlobalCombineFn<InputT, ?, OutputT> globalCombineFn,
+        TupleTag<OutputT> outputTag) {
+      checkUniqueness(outputTags, outputTag);
+      ImmutableList<SerializableFunction<DataT, ?>> nextExtractors =
+          ImmutableList.<SerializableFunction<DataT, ?>>builder()
+              .addAll(extractInputFns)
+              .add(extractInputFn)
+              .build();
+      ImmutableList<CombineFnWithContext<?, ?, ?>> nextFns =
+          ImmutableList.<CombineFnWithContext<?, ?, ?>>builder()
+              .addAll(combineFnWithContexts)
+              .add(toFnWithContext(globalCombineFn))
+              .build();
+      ImmutableList<TupleTag<?>> nextTags =
+          ImmutableList.<TupleTag<?>>builder()
+              .addAll(outputTags)
+              .add(outputTag)
+              .build();
+      return new ComposedCombineFnWithContext<>(nextExtractors, nextFns, nextTags);
+    }
+
+    @Override
+    public Object[] createAccumulator(Context c) {
+      Object[] accumulators = new Object[combineFnCount];
+      for (int idx = 0; idx < accumulators.length; idx++) {
+        accumulators[idx] = combineFnWithContexts.get(idx).createAccumulator(c);
+      }
+      return accumulators;
+    }
+
+    @Override
+    public Object[] addInput(Object[] accumulator, DataT value, Context c) {
+      for (int idx = 0; idx < combineFnCount; idx++) {
+        Object extracted = extractInputFns.get(idx).apply(value);
+        CombineFnWithContext<Object, Object, Object> fn = combineFnWithContexts.get(idx);
+        accumulator[idx] = fn.addInput(accumulator[idx], extracted, c);
+      }
+      return accumulator;
+    }
+
+    @Override
+    public Object[] mergeAccumulators(Iterable<Object[]> accumulators, Context c) {
+      Iterator<Object[]> it = accumulators.iterator();
+      if (!it.hasNext()) {
+        return createAccumulator(c);
+      }
+      // The first accumulator is recycled as the result. Overwriting slot idx in place is
+      // safe: merging slot idx reads only slot idx of every accumulator, and later slots are
+      // not touched until their own turn.
+      Object[] merged = it.next();
+      for (int idx = 0; idx < combineFnCount; idx++) {
+        CombineFnWithContext<Object, Object, Object> fn = combineFnWithContexts.get(idx);
+        merged[idx] = fn.mergeAccumulators(new ProjectionIterable(accumulators, idx), c);
+      }
+      return merged;
+    }
+
+    @Override
+    public CoCombineResult extractOutput(Object[] accumulator, Context c) {
+      Map<TupleTag<?>, Object> outputs = Maps.newHashMap();
+      for (int idx = 0; idx < combineFnCount; idx++) {
+        TupleTag<?> tag = outputTags.get(idx);
+        outputs.put(tag, combineFnWithContexts.get(idx).extractOutput(accumulator[idx], c));
+      }
+      return new CoCombineResult(outputs);
+    }
+
+    @Override
+    public Object[] compact(Object[] accumulator, Context c) {
+      for (int idx = 0; idx < combineFnCount; idx++) {
+        CombineFnWithContext<Object, Object, Object> fn = combineFnWithContexts.get(idx);
+        accumulator[idx] = fn.compact(accumulator[idx], c);
+      }
+      return accumulator;
+    }
+
+    @Override
+    public Coder<Object[]> getAccumulatorCoder(CoderRegistry registry, Coder<DataT> dataCoder)
+        throws CannotProvideCoderException {
+      List<Coder<Object>> componentCoders = Lists.newArrayList();
+      for (int idx = 0; idx < combineFnCount; idx++) {
+        SerializableFunction<DataT, Object> extractor = extractInputFns.get(idx);
+        Coder<Object> inputCoder = registry.getDefaultOutputCoder(extractor, dataCoder);
+        componentCoders.add(
+            combineFnWithContexts.get(idx).getAccumulatorCoder(registry, inputCoder));
+      }
+      return new ComposedAccumulatorCoder(componentCoders);
+    }
+  }
+
+  /**
+   * A composed {@link KeyedCombineFn} that applies multiple {@link KeyedCombineFn KeyedCombineFns}.
+   *
+   * <p>For each {@link KeyedCombineFn} it extracts inputs from {@code DataT} with
+   * the {@code extractInputFn} and combines them,
+   * and then it outputs each combined value with a {@link TupleTag} to a
+   * {@link CoCombineResult}.
+   */
+  public static class ComposedKeyedCombineFn<DataT, K>
+      extends KeyedCombineFn<K, DataT, Object[], CoCombineResult> {
+
+    private final List<SerializableFunction<DataT, Object>> extractInputFns;
+    private final List<KeyedCombineFn<K, Object, Object, Object>> keyedCombineFns;
+    private final List<TupleTag<?>> outputTags;
+    private final int combineFnCount;
+
+    private ComposedKeyedCombineFn() {
+      this.extractInputFns = ImmutableList.of();
+      this.keyedCombineFns = ImmutableList.of();
+      this.outputTags = ImmutableList.of();
+      this.combineFnCount = 0;
+    }
+
+    private ComposedKeyedCombineFn(
+        ImmutableList<SerializableFunction<DataT, ?>> extractInputFns,
+        ImmutableList<KeyedCombineFn<K, ?, ?, ?>> keyedCombineFns,
+        ImmutableList<TupleTag<?>> outputTags) {
+      @SuppressWarnings({"unchecked", "rawtypes"})
+      List<SerializableFunction<DataT, Object>> castedExtractInputFns = (List) extractInputFns;
+      this.extractInputFns = castedExtractInputFns;
+
+      @SuppressWarnings({"unchecked", "rawtypes"})
+      List<KeyedCombineFn<K, Object, Object, Object>> castedKeyedCombineFns =
+          (List) keyedCombineFns;
+      this.keyedCombineFns = castedKeyedCombineFns;
+      this.outputTags = outputTags;
+      this.combineFnCount = this.keyedCombineFns.size();
+    }
+
+    /**
+     * Returns a {@link ComposedKeyedCombineFn} with an additional {@link KeyedCombineFn}.
+     */
+    public <InputT, OutputT> ComposedKeyedCombineFn<DataT, K> with(
+        SimpleFunction<DataT, InputT> extractInputFn,
+        KeyedCombineFn<K, InputT, ?, OutputT> keyedCombineFn,
+        TupleTag<OutputT> outputTag) {
+      checkUniqueness(outputTags, outputTag);
+      return new ComposedKeyedCombineFn<>(
+          ImmutableList.<SerializableFunction<DataT, ?>>builder()
+          .addAll(extractInputFns)
+          .add(extractInputFn)
+          .build(),
+      ImmutableList.<KeyedCombineFn<K, ?, ?, ?>>builder()
+          .addAll(keyedCombineFns)
+          .add(keyedCombineFn)
+          .build(),
+      ImmutableList.<TupleTag<?>>builder()
+          .addAll(outputTags)
+          .add(outputTag)
+          .build());
+    }
+
+    /**
+     * Returns a {@link ComposedKeyedCombineFnWithContext} with an additional
+     * {@link KeyedCombineFnWithContext}.
+     */
+    public <InputT, OutputT> ComposedKeyedCombineFnWithContext<DataT, K> with(
+        SimpleFunction<DataT, InputT> extractInputFn,
+        KeyedCombineFnWithContext<K, InputT, ?, OutputT> keyedCombineFn,
+        TupleTag<OutputT> outputTag) {
+      checkUniqueness(outputTags, outputTag);
+      List<KeyedCombineFnWithContext<K, Object, Object, Object>> fnsWithContext =
+          Lists.newArrayList();
+      for (KeyedCombineFn<K, Object, Object, Object> fn : keyedCombineFns) {
+        fnsWithContext.add(toFnWithContext(fn));
+      }
+      return new ComposedKeyedCombineFnWithContext<>(
+          ImmutableList.<SerializableFunction<DataT, ?>>builder()
+          .addAll(extractInputFns)
+          .add(extractInputFn)
+          .build(),
+      ImmutableList.<KeyedCombineFnWithContext<K, ?, ?, ?>>builder()
+          .addAll(fnsWithContext)
+          .add(keyedCombineFn)
+          .build(),
+      ImmutableList.<TupleTag<?>>builder()
+          .addAll(outputTags)
+          .add(outputTag)
+          .build());
+    }
+
+    /**
+     * Returns a {@link ComposedKeyedCombineFn} with an additional {@link CombineFn}.
+     */
+    public <InputT, OutputT> ComposedKeyedCombineFn<DataT, K> with(
+        SimpleFunction<DataT, InputT> extractInputFn,
+        CombineFn<InputT, ?, OutputT> keyedCombineFn,
+        TupleTag<OutputT> outputTag) {
+      return with(extractInputFn, keyedCombineFn.<K>asKeyedFn(), outputTag);
+    }
+
+    /**
+     * Returns a {@link ComposedKeyedCombineFnWithContext} with an additional
+     * {@link CombineFnWithContext}.
+     */
+    public <InputT, OutputT> ComposedKeyedCombineFnWithContext<DataT, K> with(
+        SimpleFunction<DataT, InputT> extractInputFn,
+        CombineFnWithContext<InputT, ?, OutputT> keyedCombineFn,
+        TupleTag<OutputT> outputTag) {
+      return with(extractInputFn, keyedCombineFn.<K>asKeyedFn(), outputTag);
+    }
+
+    @Override
+    public Object[] createAccumulator(K key) {
+      Object[] accumsArray = new Object[combineFnCount];
+      for (int i = 0; i < combineFnCount; ++i) {
+        accumsArray[i] = keyedCombineFns.get(i).createAccumulator(key);
+      }
+      return accumsArray;
+    }
+
+    @Override
+    public Object[] addInput(K key, Object[] accumulator, DataT value) {
+      for (int i = 0; i < combineFnCount; ++i) {
+        Object input = extractInputFns.get(i).apply(value);
+        accumulator[i] = keyedCombineFns.get(i).addInput(key, accumulator[i], input);
+      }
+      return accumulator;
+    }
+
+    @Override
+    public Object[] mergeAccumulators(K key, final Iterable<Object[]> accumulators) {
+      Iterator<Object[]> iter = accumulators.iterator();
+      if (!iter.hasNext()) {
+        return createAccumulator(key);
+      } else {
+        // Reuses the first accumulator, and overwrites its values.
+        // It is safe because {@code accum[i]} only depends on
+        // the i-th component of each accumulator.
+        Object[] accum = iter.next();
+        for (int i = 0; i < combineFnCount; ++i) {
+          accum[i] = keyedCombineFns.get(i).mergeAccumulators(
+              key, new ProjectionIterable(accumulators, i));
+        }
+        return accum;
+      }
+    }
+
+    @Override
+    public CoCombineResult extractOutput(K key, Object[] accumulator) {
+      Map<TupleTag<?>, Object> valuesMap = Maps.newHashMap();
+      for (int i = 0; i < combineFnCount; ++i) {
+        valuesMap.put(
+            outputTags.get(i),
+            keyedCombineFns.get(i).extractOutput(key, accumulator[i]));
+      }
+      return new CoCombineResult(valuesMap);
+    }
+
+    @Override
+    public Object[] compact(K key, Object[] accumulator) {
+      for (int i = 0; i < combineFnCount; ++i) {
+        accumulator[i] = keyedCombineFns.get(i).compact(key, accumulator[i]);
+      }
+      return accumulator;
+    }
+
+    @Override
+    public Coder<Object[]> getAccumulatorCoder(
+        CoderRegistry registry, Coder<K> keyCoder, Coder<DataT> dataCoder)
+        throws CannotProvideCoderException {
+      List<Coder<Object>> componentCoders = Lists.newArrayListWithCapacity(combineFnCount);
+      for (int i = 0; i < combineFnCount; ++i) {
+        // The i-th fn consumes whatever the i-th extractor produces, so infer that coder first.
+        Coder<Object> extractedInputCoder =
+            registry.getDefaultOutputCoder(extractInputFns.get(i), dataCoder);
+        componentCoders.add(
+            keyedCombineFns.get(i).getAccumulatorCoder(registry, keyCoder, extractedInputCoder));
+      }
+      return new ComposedAccumulatorCoder(componentCoders);
+    }
+  }
+
+  /**
+   * A composed {@link KeyedCombineFnWithContext} that applies multiple
+   * {@link KeyedCombineFnWithContext KeyedCombineFnWithContexts}.
+   *
+   * <p>For each {@link KeyedCombineFnWithContext} it extracts inputs from {@code DataT} with
+   * the {@code extractInputFn} and combines them,
+   * and then it outputs each combined value with a {@link TupleTag} to a
+   * {@link CoCombineResult}.
+   *
+   * <p>The accumulator is an {@code Object[]} whose i-th slot holds the accumulator of the i-th
+   * composed fn. {@code with(...)} never modifies the receiver; it returns a new composition.
+   */
+  public static class ComposedKeyedCombineFnWithContext<DataT, K>
+      extends KeyedCombineFnWithContext<K, DataT, Object[], CoCombineResult> {
+
+    // The three lists are parallel: index i describes the i-th composed fn.
+    private final List<SerializableFunction<DataT, Object>> extractInputFns;
+    private final List<KeyedCombineFnWithContext<K, Object, Object, Object>> keyedCombineFns;
+    private final List<TupleTag<?>> outputTags;
+    private final int combineFnCount;
+
+    /** Creates an empty composition. */
+    private ComposedKeyedCombineFnWithContext() {
+      this.extractInputFns = ImmutableList.of();
+      this.keyedCombineFns = ImmutableList.of();
+      this.outputTags = ImmutableList.of();
+      this.combineFnCount = 0;
+    }
+
+    /**
+     * Creates a composition from parallel lists of extractors, fns and output tags.
+     * The caller must keep the three lists the same length and the tags unique.
+     */
+    private ComposedKeyedCombineFnWithContext(
+        ImmutableList<SerializableFunction<DataT, ?>> extractInputFns,
+        ImmutableList<KeyedCombineFnWithContext<K, ?, ?, ?>> keyedCombineFns,
+        ImmutableList<TupleTag<?>> outputTags) {
+      // Element types are widened to Object so all fns can be driven uniformly; with(...)
+      // pairs each extractor with a fn whose input type matches the extractor's output.
+      @SuppressWarnings({"unchecked", "rawtypes"})
+      List<SerializableFunction<DataT, Object>> castedExtractInputFns =
+          (List) extractInputFns;
+      this.extractInputFns = castedExtractInputFns;
+
+      @SuppressWarnings({"unchecked", "rawtypes"})
+      List<KeyedCombineFnWithContext<K, Object, Object, Object>> castedKeyedCombineFns =
+          (List) keyedCombineFns;
+      this.keyedCombineFns = castedKeyedCombineFns;
+      this.outputTags = outputTags;
+      this.combineFnCount = this.keyedCombineFns.size();
+    }
+
+    /**
+     * Returns a {@link ComposedKeyedCombineFnWithContext} with an additional
+     * {@link PerKeyCombineFn}.
+     *
+     * @param extractInputFn extracts the input of {@code perKeyCombineFn} from each element
+     * @param perKeyCombineFn the fn to add; if it is not already a
+     *     {@link KeyedCombineFnWithContext} it is wrapped in one that ignores the context
+     * @param outputTag the tag under which the combined value appears in the
+     *     {@link CoCombineResult}
+     * @return a new composition; this instance is left unchanged
+     * @throws IllegalArgumentException if {@code outputTag} is already used in this composition
+     */
+    public <InputT, OutputT> ComposedKeyedCombineFnWithContext<DataT, K> with(
+        SimpleFunction<DataT, InputT> extractInputFn,
+        PerKeyCombineFn<K, InputT, ?, OutputT> perKeyCombineFn,
+        TupleTag<OutputT> outputTag) {
+      checkUniqueness(outputTags, outputTag);
+      return new ComposedKeyedCombineFnWithContext<>(
+          ImmutableList.<SerializableFunction<DataT, ?>>builder()
+              .addAll(extractInputFns)
+              .add(extractInputFn)
+              .build(),
+          ImmutableList.<KeyedCombineFnWithContext<K, ?, ?, ?>>builder()
+              .addAll(keyedCombineFns)
+              .add(toFnWithContext(perKeyCombineFn))
+              .build(),
+          ImmutableList.<TupleTag<?>>builder()
+              .addAll(outputTags)
+              .add(outputTag)
+              .build());
+    }
+
+    /**
+     * Returns a {@link ComposedKeyedCombineFnWithContext} with an additional
+     * {@link GlobalCombineFn}.
+     *
+     * <p>The fn is converted with {@code asKeyedFn()} and added through
+     * {@link #with(SimpleFunction, PerKeyCombineFn, TupleTag)}.
+     *
+     * @param extractInputFn extracts the input of {@code globalCombineFn} from each element
+     * @param globalCombineFn the key-agnostic fn to add
+     * @param outputTag the tag under which the combined value appears in the
+     *     {@link CoCombineResult}
+     * @return a new composition; this instance is left unchanged
+     * @throws IllegalArgumentException if {@code outputTag} is already used in this composition
+     */
+    public <InputT, OutputT> ComposedKeyedCombineFnWithContext<DataT, K> with(
+        SimpleFunction<DataT, InputT> extractInputFn,
+        GlobalCombineFn<InputT, ?, OutputT> globalCombineFn,
+        TupleTag<OutputT> outputTag) {
+      return with(extractInputFn, globalCombineFn.<K>asKeyedFn(), outputTag);
+    }
+
+    @Override
+    public Object[] createAccumulator(K key, Context c) {
+      Object[] accumsArray = new Object[combineFnCount];
+      for (int i = 0; i < combineFnCount; ++i) {
+        accumsArray[i] = keyedCombineFns.get(i).createAccumulator(key, c);
+      }
+      return accumsArray;
+    }
+
+    @Override
+    public Object[] addInput(K key, Object[] accumulator, DataT value, Context c) {
+      // Each fn sees only its extracted projection of value; slots are updated in place.
+      for (int i = 0; i < combineFnCount; ++i) {
+        Object input = extractInputFns.get(i).apply(value);
+        accumulator[i] = keyedCombineFns.get(i).addInput(key, accumulator[i], input, c);
+      }
+      return accumulator;
+    }
+
+    @Override
+    public Object[] mergeAccumulators(K key, Iterable<Object[]> accumulators, Context c) {
+      Iterator<Object[]> iter = accumulators.iterator();
+      if (!iter.hasNext()) {
+        return createAccumulator(key, c);
+      } else {
+        // Reuses the first accumulator, and overwrites its values.
+        // It is safe because {@code accum[i]} only depends on
+        // the i-th component of each accumulator.
+        Object[] accum = iter.next();
+        for (int i = 0; i < combineFnCount; ++i) {
+          accum[i] = keyedCombineFns.get(i).mergeAccumulators(
+              key, new ProjectionIterable(accumulators, i), c);
+        }
+        return accum;
+      }
+    }
+
+    @Override
+    public CoCombineResult extractOutput(K key, Object[] accumulator, Context c) {
+      // Output i is published under the tag registered at the same index.
+      Map<TupleTag<?>, Object> valuesMap = Maps.newHashMap();
+      for (int i = 0; i < combineFnCount; ++i) {
+        valuesMap.put(
+            outputTags.get(i),
+            keyedCombineFns.get(i).extractOutput(key, accumulator[i], c));
+      }
+      return new CoCombineResult(valuesMap);
+    }
+
+    @Override
+    public Object[] compact(K key, Object[] accumulator, Context c) {
+      for (int i = 0; i < combineFnCount; ++i) {
+        accumulator[i] = keyedCombineFns.get(i).compact(key, accumulator[i], c);
+      }
+      return accumulator;
+    }
+
+    @Override
+    public Coder<Object[]> getAccumulatorCoder(
+        CoderRegistry registry, Coder<K> keyCoder, Coder<DataT> dataCoder)
+        throws CannotProvideCoderException {
+      List<Coder<Object>> coders = Lists.newArrayListWithCapacity(combineFnCount);
+      for (int i = 0; i < combineFnCount; ++i) {
+        // The input coder of fn i is the output coder of extractor i.
+        Coder<Object> inputCoder =
+            registry.getDefaultOutputCoder(extractInputFns.get(i), dataCoder);
+        coders.add(keyedCombineFns.get(i).getAccumulatorCoder(
+            registry, keyCoder, inputCoder));
+      }
+      return new ComposedAccumulatorCoder(coders);
+    }
+  }
+
+  /////////////////////////////////////////////////////////////////////////////
+
+  private static class ProjectionIterable implements Iterable<Object> {
+    private final Iterable<Object[]> iterable;
+    private final int column;
+
+    private ProjectionIterable(Iterable<Object[]> iterable, int column) {
+      this.iterable = iterable;
+      this.column = column;
+    }
+
+    @Override
+    public Iterator<Object> iterator() {
+      final Iterator<Object[]> iter = iterable.iterator();
+      return new Iterator<Object>() {
+        @Override
+        public boolean hasNext() {
+          return iter.hasNext();
+        }
+
+        @Override
+        public Object next() {
+          return iter.next()[column];
+        }
+
+        @Override
+        public void remove() {
+            throw new UnsupportedOperationException();
+        }
+      };
+    }
+  }
+
+  /**
+   * Coder for the {@code Object[]} accumulator of a composed combine fn. Element i is encoded
+   * with the i-th component coder, in index order, each in a nested context.
+   */
+  private static class ComposedAccumulatorCoder extends StandardCoder<Object[]> {
+    // Assigned once in the constructor; the list is an immutable copy.
+    private final List<Coder<Object>> coders;
+    private final int codersCount;
+
+    public ComposedAccumulatorCoder(List<Coder<Object>> coders) {
+      this.coders = ImmutableList.copyOf(coders);
+      this.codersCount = this.coders.size();
+    }
+
+    /**
+     * Jackson factory (see {@code @JsonCreator}) that rebuilds the coder from its component
+     * encodings.
+     */
+    @SuppressWarnings({"rawtypes", "unchecked"})
+    @JsonCreator
+    public static ComposedAccumulatorCoder of(
+        @JsonProperty(PropertyNames.COMPONENT_ENCODINGS)
+        List<Coder<?>> components) {
+      return new ComposedAccumulatorCoder((List) components);
+    }
+
+    /**
+     * Encodes each slot of {@code value} with its component coder.
+     *
+     * @throws IllegalArgumentException if {@code value} does not have one slot per component coder
+     */
+    @Override
+    public void encode(Object[] value, OutputStream outStream, Context context)
+        throws CoderException, IOException {
+      checkArgument(
+          value.length == codersCount,
+          "Expected an accumulator array of length %s, but got %s",
+          codersCount,
+          value.length);
+      // Every component is nested so that its encoding is self-delimiting in the stream.
+      Context nestedContext = context.nested();
+      for (int i = 0; i < codersCount; ++i) {
+        coders.get(i).encode(value[i], outStream, nestedContext);
+      }
+    }
+
+    @Override
+    public Object[] decode(InputStream inStream, Context context)
+        throws CoderException, IOException {
+      Object[] ret = new Object[codersCount];
+      Context nestedContext = context.nested();
+      for (int i = 0; i < codersCount; ++i) {
+        ret[i] = coders.get(i).decode(inStream, nestedContext);
+      }
+      return ret;
+    }
+
+    @Override
+    public List<? extends Coder<?>> getCoderArguments() {
+      return coders;
+    }
+
+    /** Deterministic only if every component coder is deterministic. */
+    @Override
+    public void verifyDeterministic() throws NonDeterministicException {
+      for (int i = 0; i < codersCount; ++i) {
+        coders.get(i).verifyDeterministic();
+      }
+    }
+  }
+
+  /**
+   * Adapts a {@link GlobalCombineFn} to a {@link CombineFnWithContext}.
+   *
+   * <p>A {@link CombineFnWithContext} is returned as-is. Any other input is cast to a
+   * {@link CombineFn} (so a {@code ClassCastException} is thrown for other subtypes) and
+   * wrapped in an adapter whose methods delegate to it and ignore the {@code Context}.
+   */
+  private static <InputT, AccumT, OutputT> CombineFnWithContext<InputT, AccumT, OutputT>
+  toFnWithContext(GlobalCombineFn<InputT, AccumT, OutputT> globalCombineFn) {
+    if (globalCombineFn instanceof CombineFnWithContext) {
+      @SuppressWarnings("unchecked")
+      CombineFnWithContext<InputT, AccumT, OutputT> combineFnWithContext =
+          (CombineFnWithContext<InputT, AccumT, OutputT>) globalCombineFn;
+      return combineFnWithContext;
+    } else {
+      @SuppressWarnings("unchecked")
+      final CombineFn<InputT, AccumT, OutputT> combineFn =
+          (CombineFn<InputT, AccumT, OutputT>) globalCombineFn;
+      return new CombineFnWithContext<InputT, AccumT, OutputT>() {
+        @Override
+        public AccumT createAccumulator(Context c) {
+          return combineFn.createAccumulator();
+        }
+        @Override
+        public AccumT addInput(AccumT accumulator, InputT input, Context c) {
+          return combineFn.addInput(accumulator, input);
+        }
+        @Override
+        public AccumT mergeAccumulators(Iterable<AccumT> accumulators, Context c) {
+          return combineFn.mergeAccumulators(accumulators);
+        }
+        @Override
+        public OutputT extractOutput(AccumT accumulator, Context c) {
+          return combineFn.extractOutput(accumulator);
+        }
+        @Override
+        public AccumT compact(AccumT accumulator, Context c) {
+          return combineFn.compact(accumulator);
+        }
+        @Override
+        public OutputT defaultValue() {
+          return combineFn.defaultValue();
+        }
+        @Override
+        public Coder<AccumT> getAccumulatorCoder(CoderRegistry registry, Coder<InputT> inputCoder)
+            throws CannotProvideCoderException {
+          return combineFn.getAccumulatorCoder(registry, inputCoder);
+        }
+        @Override
+        public Coder<OutputT> getDefaultOutputCoder(
+            CoderRegistry registry, Coder<InputT> inputCoder) throws CannotProvideCoderException {
+          return combineFn.getDefaultOutputCoder(registry, inputCoder);
+        }
+      };
+    }
+  }
+
+  /**
+   * Adapts a {@link PerKeyCombineFn} to a {@link KeyedCombineFnWithContext}.
+   *
+   * <p>A {@link KeyedCombineFnWithContext} is returned unchanged; anything else is cast to a
+   * {@link KeyedCombineFn} and wrapped in an adapter that delegates every call to it and
+   * ignores the {@code Context} argument.
+   */
+  private static <K, InputT, AccumT, OutputT> KeyedCombineFnWithContext<K, InputT, AccumT, OutputT>
+  toFnWithContext(PerKeyCombineFn<K, InputT, AccumT, OutputT> perKeyCombineFn) {
+    if (perKeyCombineFn instanceof KeyedCombineFnWithContext) {
+      @SuppressWarnings("unchecked")
+      KeyedCombineFnWithContext<K, InputT, AccumT, OutputT> alreadyWithContext =
+          (KeyedCombineFnWithContext<K, InputT, AccumT, OutputT>) perKeyCombineFn;
+      return alreadyWithContext;
+    }
+
+    @SuppressWarnings("unchecked")
+    final KeyedCombineFn<K, InputT, AccumT, OutputT> wrappedFn =
+        (KeyedCombineFn<K, InputT, AccumT, OutputT>) perKeyCombineFn;
+    return new KeyedCombineFnWithContext<K, InputT, AccumT, OutputT>() {
+      @Override
+      public AccumT createAccumulator(K key, Context c) {
+        return wrappedFn.createAccumulator(key);
+      }
+
+      @Override
+      public AccumT addInput(K key, AccumT accumulator, InputT input, Context c) {
+        return wrappedFn.addInput(key, accumulator, input);
+      }
+
+      @Override
+      public AccumT mergeAccumulators(K key, Iterable<AccumT> accumulators, Context c) {
+        return wrappedFn.mergeAccumulators(key, accumulators);
+      }
+
+      @Override
+      public OutputT extractOutput(K key, AccumT accumulator, Context c) {
+        return wrappedFn.extractOutput(key, accumulator);
+      }
+
+      @Override
+      public AccumT compact(K key, AccumT accumulator, Context c) {
+        return wrappedFn.compact(key, accumulator);
+      }
+
+      @Override
+      public Coder<AccumT> getAccumulatorCoder(CoderRegistry registry, Coder<K> keyCoder,
+          Coder<InputT> inputCoder) throws CannotProvideCoderException {
+        return wrappedFn.getAccumulatorCoder(registry, keyCoder, inputCoder);
+      }
+
+      @Override
+      public Coder<OutputT> getDefaultOutputCoder(CoderRegistry registry, Coder<K> keyCoder,
+          Coder<InputT> inputCoder) throws CannotProvideCoderException {
+        return wrappedFn.getDefaultOutputCoder(registry, keyCoder, inputCoder);
+      }
+    };
+  }
+
+  /**
+   * Rejects {@code outputTag} if it already appears in {@code registeredTags}, so every output
+   * of a composition is addressed by a distinct tag.
+   *
+   * @throws IllegalArgumentException if the tag is already registered
+   */
+  private static <OutputT> void checkUniqueness(
+      List<TupleTag<?>> registeredTags, TupleTag<OutputT> outputTag) {
+    boolean alreadyRegistered = registeredTags.contains(outputTag);
+    checkArgument(
+        !alreadyRegistered,
+        "Cannot compose with tuple tag %s because it is already present in the composition.",
+        outputTag);
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/23b43780/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/CombineFnsTest.java
----------------------------------------------------------------------
diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/CombineFnsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/CombineFnsTest.java
new file mode 100644
index 0000000..ad37708
--- /dev/null
+++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/CombineFnsTest.java
@@ -0,0 +1,413 @@
+/*
+ * Copyright (C) 2016 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package com.google.cloud.dataflow.sdk.transforms;
+
+import static org.junit.Assert.assertThat;
+
+import com.google.cloud.dataflow.sdk.Pipeline;
+import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder;
+import com.google.cloud.dataflow.sdk.coders.Coder;
+import com.google.cloud.dataflow.sdk.coders.CoderException;
+import com.google.cloud.dataflow.sdk.coders.KvCoder;
+import com.google.cloud.dataflow.sdk.coders.NullableCoder;
+import com.google.cloud.dataflow.sdk.coders.StandardCoder;
+import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
+import com.google.cloud.dataflow.sdk.testing.DataflowAssert;
+import com.google.cloud.dataflow.sdk.testing.RunnableOnService;
+import com.google.cloud.dataflow.sdk.testing.TestPipeline;
+import com.google.cloud.dataflow.sdk.transforms.Combine.BinaryCombineFn;
+import com.google.cloud.dataflow.sdk.transforms.CombineFns.CoCombineResult;
+import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext;
+import com.google.cloud.dataflow.sdk.transforms.Max.MaxIntegerFn;
+import com.google.cloud.dataflow.sdk.transforms.Min.MinIntegerFn;
+import com.google.cloud.dataflow.sdk.values.KV;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+import com.google.cloud.dataflow.sdk.values.PCollectionView;
+import com.google.cloud.dataflow.sdk.values.TupleTag;
+import com.google.common.collect.ImmutableList;
+
+import org.hamcrest.Matchers;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.rules.ExpectedException;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Unit tests for {@link CombineFns}.
+ */
+@RunWith(JUnit4.class)
+public class CombineFnsTest {
+  @Rule public ExpectedException expectedException = ExpectedException.none();
+
+  @Test
+  public void testDuplicatedTags() {
+    expectedException.expect(IllegalArgumentException.class);
+    expectedException.expectMessage("it is already present in the composition");
+
+    TupleTag<Integer> tag = new TupleTag<Integer>();
+    CombineFns.compose()
+      .with(new GetIntegerFunction(), new MaxIntegerFn(), tag)
+      .with(new GetIntegerFunction(), new MinIntegerFn(), tag);
+  }
+
+  @Test
+  public void testDuplicatedTagsKeyed() {
+    expectedException.expect(IllegalArgumentException.class);
+    expectedException.expectMessage("it is already present in the composition");
+
+    TupleTag<Integer> tag = new TupleTag<Integer>();
+    CombineFns.composeKeyed()
+      .with(new GetIntegerFunction(), new MaxIntegerFn(), tag)
+      .with(new GetIntegerFunction(), new MinIntegerFn(), tag);
+  }
+
+  @Test
+  public void testDuplicatedTagsWithContext() {
+    expectedException.expect(IllegalArgumentException.class);
+    expectedException.expectMessage("it is already present in the composition");
+
+    TupleTag<UserString> tag = new TupleTag<UserString>();
+    CombineFns.compose()
+      .with(
+          new GetUserStringFunction(),
+          new ConcatStringWithContext(null /* view */).forKey("G", StringUtf8Coder.of()),
+          tag)
+      .with(
+          new GetUserStringFunction(),
+          new ConcatStringWithContext(null /* view */).forKey("G", StringUtf8Coder.of()),
+          tag);
+  }
+
+  @Test
+  public void testDuplicatedTagsWithContextKeyed() {
+    expectedException.expect(IllegalArgumentException.class);
+    expectedException.expectMessage("it is already present in the composition");
+
+    TupleTag<UserString> tag = new TupleTag<UserString>();
+    CombineFns.composeKeyed()
+      .with(
+          new GetUserStringFunction(),
+          new ConcatStringWithContext(null /* view */),
+          tag)
+      .with(
+          new GetUserStringFunction(),
+          new ConcatStringWithContext(null /* view */),
+          tag);
+  }
+
+  @Test
+  @Category(RunnableOnService.class)
+  public void testComposedCombine() {
+    Pipeline p = TestPipeline.create();
+    p.getCoderRegistry().registerCoder(UserString.class, UserStringCoder.of());
+
+    PCollection<KV<String, KV<Integer, UserString>>> perKeyInput = p.apply(
+        Create.timestamped(
+            Arrays.asList(
+                KV.of("a", KV.of(1, UserString.of("1"))),
+                KV.of("a", KV.of(1, UserString.of("1"))),
+                KV.of("a", KV.of(4, UserString.of("4"))),
+                KV.of("b", KV.of(1, UserString.of("1"))),
+                KV.of("b", KV.of(13, UserString.of("13")))),
+            Arrays.asList(0L, 4L, 7L, 10L, 16L))
+        .withCoder(KvCoder.of(
+            StringUtf8Coder.of(),
+            KvCoder.of(BigEndianIntegerCoder.of(), UserStringCoder.of()))));
+
+    TupleTag<Integer> maxIntTag = new TupleTag<Integer>();
+    TupleTag<UserString> concatStringTag = new TupleTag<UserString>();
+    PCollection<KV<String, KV<Integer, String>>> combineGlobally = perKeyInput
+        .apply(Values.<KV<Integer, UserString>>create())
+        .apply(Combine.globally(CombineFns.compose()
+            .with(
+                new GetIntegerFunction(),
+                new MaxIntegerFn(),
+                maxIntTag)
+            .with(
+                new GetUserStringFunction(),
+                new ConcatString(),
+                concatStringTag)))
+        .apply(WithKeys.<String, CoCombineResult>of("global"))
+        .apply(
+            "ExtractGloballyResult", ParDo.of(new ExtractResultDoFn(maxIntTag, concatStringTag)));
+
+    PCollection<KV<String, KV<Integer, String>>> combinePerKey = perKeyInput
+        .apply(Combine.perKey(CombineFns.composeKeyed()
+            .with(
+                new GetIntegerFunction(),
+                new MaxIntegerFn().<String>asKeyedFn(),
+                maxIntTag)
+            .with(
+                new GetUserStringFunction(),
+                new ConcatString().<String>asKeyedFn(),
+                concatStringTag)))
+        .apply("ExtractPerKeyResult", ParDo.of(new ExtractResultDoFn(maxIntTag, concatStringTag)));
+    DataflowAssert.that(combineGlobally).containsInAnyOrder(
+        KV.of("global", KV.of(13, "111134")));
+    DataflowAssert.that(combinePerKey).containsInAnyOrder(
+        KV.of("a", KV.of(4, "114")),
+        KV.of("b", KV.of(13, "113")));
+    p.run();
+  }
+
+  @Test
+  @Category(RunnableOnService.class)
+  public void testComposedCombineWithContext() {
+    Pipeline p = TestPipeline.create();
+    p.getCoderRegistry().registerCoder(UserString.class, UserStringCoder.of());
+
+    PCollectionView<String> view = p
+        .apply(Create.of("I"))
+        .apply(View.<String>asSingleton());
+
+    PCollection<KV<String, KV<Integer, UserString>>> perKeyInput = p.apply(
+        Create.timestamped(
+            Arrays.asList(
+                KV.of("a", KV.of(1, UserString.of("1"))),
+                KV.of("a", KV.of(1, UserString.of("1"))),
+                KV.of("a", KV.of(4, UserString.of("4"))),
+                KV.of("b", KV.of(1, UserString.of("1"))),
+                KV.of("b", KV.of(13, UserString.of("13")))),
+            Arrays.asList(0L, 4L, 7L, 10L, 16L))
+        .withCoder(KvCoder.of(
+            StringUtf8Coder.of(),
+            KvCoder.of(BigEndianIntegerCoder.of(), UserStringCoder.of()))));
+
+    TupleTag<Integer> maxIntTag = new TupleTag<Integer>();
+    TupleTag<UserString> concatStringTag = new TupleTag<UserString>();
+    PCollection<KV<String, KV<Integer, String>>> combineGlobally = perKeyInput
+        .apply(Values.<KV<Integer, UserString>>create())
+        .apply(Combine.globally(CombineFns.compose()
+            .with(
+                new GetIntegerFunction(),
+                new MaxIntegerFn(),
+                maxIntTag)
+            .with(
+                new GetUserStringFunction(),
+                new ConcatStringWithContext(view).forKey("G", StringUtf8Coder.of()),
+                concatStringTag))
+            .withoutDefaults()
+            .withSideInputs(ImmutableList.of(view)))
+        .apply(WithKeys.<String, CoCombineResult>of("global"))
+        .apply(
+            "ExtractGloballyResult", ParDo.of(new ExtractResultDoFn(maxIntTag, concatStringTag)));
+
+    PCollection<KV<String, KV<Integer, String>>> combinePerKey = perKeyInput
+        .apply(Combine.perKey(CombineFns.composeKeyed()
+            .with(
+                new GetIntegerFunction(),
+                new MaxIntegerFn().<String>asKeyedFn(),
+                maxIntTag)
+            .with(
+                new GetUserStringFunction(),
+                new ConcatStringWithContext(view),
+                concatStringTag))
+            .withSideInputs(ImmutableList.of(view)))
+        .apply("ExtractPerKeyResult", ParDo.of(new ExtractResultDoFn(maxIntTag, concatStringTag)));
+    DataflowAssert.that(combineGlobally).containsInAnyOrder(
+        KV.of("global", KV.of(13, "111134GI")));
+    DataflowAssert.that(combinePerKey).containsInAnyOrder(
+        KV.of("a", KV.of(4, "114Ia")),
+        KV.of("b", KV.of(13, "113Ib")));
+    p.run();
+  }
+
+  @Test
+  @Category(RunnableOnService.class)
+  public void testComposedCombineNullValues() {
+    Pipeline p = TestPipeline.create();
+    p.getCoderRegistry().registerCoder(UserString.class, NullableCoder.of(UserStringCoder.of()));
+    p.getCoderRegistry().registerCoder(String.class, NullableCoder.of(StringUtf8Coder.of()));
+
+    PCollection<KV<String, KV<Integer, UserString>>> perKeyInput = p.apply(
+        Create.timestamped(
+            Arrays.asList(
+                KV.of("a", KV.of(1, UserString.of("1"))),
+                KV.of("a", KV.of(1, UserString.of("1"))),
+                KV.of("a", KV.of(4, UserString.of("4"))),
+                KV.of("b", KV.of(1, UserString.of("1"))),
+                KV.of("b", KV.of(13, UserString.of("13")))),
+            Arrays.asList(0L, 4L, 7L, 10L, 16L))
+        .withCoder(KvCoder.of(
+            StringUtf8Coder.of(),
+            KvCoder.of(
+                BigEndianIntegerCoder.of(), NullableCoder.of(UserStringCoder.of())))));
+
+    TupleTag<Integer> maxIntTag = new TupleTag<Integer>();
+    TupleTag<UserString> concatStringTag = new TupleTag<UserString>();
+
+    PCollection<KV<String, KV<Integer, String>>> combinePerKey = perKeyInput
+        .apply(Combine.perKey(CombineFns.composeKeyed()
+            .with(
+                new GetIntegerFunction(),
+                new MaxIntegerFn().<String>asKeyedFn(),
+                maxIntTag)
+            .with(
+                new GetUserStringFunction(),
+                new OutputNullString().<String>asKeyedFn(),
+                concatStringTag)))
+        .apply("ExtractPerKeyResult", ParDo.of(new ExtractResultDoFn(maxIntTag, concatStringTag)));
+    DataflowAssert.that(combinePerKey).containsInAnyOrder(
+        KV.of("a", KV.of(4, (String) null)),
+        KV.of("b", KV.of(13, (String) null)));
+    p.run();
+  }
+
+  private static class UserString implements Serializable {
+    private String strValue;
+
+    static UserString of(String strValue) {
+      UserString ret = new UserString();
+      ret.strValue = strValue;
+      return ret;
+    }
+  }
+
+  private static class UserStringCoder extends StandardCoder<UserString> {
+    public static UserStringCoder of() {
+      return INSTANCE;
+    }
+
+    private static final UserStringCoder INSTANCE = new UserStringCoder();
+
+    @Override
+    public void encode(UserString value, OutputStream outStream, Context context)
+        throws CoderException, IOException {
+      StringUtf8Coder.of().encode(value.strValue, outStream, context);
+    }
+
+    @Override
+    public UserString decode(InputStream inStream, Context context)
+        throws CoderException, IOException {
+      return UserString.of(StringUtf8Coder.of().decode(inStream, context));
+    }
+
+    @Override
+    public List<? extends Coder<?>> getCoderArguments() {
+      return null;
+    }
+
+    @Override
+    public void verifyDeterministic() throws NonDeterministicException {}
+  }
+
+  private static class GetIntegerFunction
+      extends SimpleFunction<KV<Integer, UserString>, Integer> {
+    @Override
+    public Integer apply(KV<Integer, UserString> input) {
+      return input.getKey();
+    }
+  }
+
+  private static class GetUserStringFunction
+      extends SimpleFunction<KV<Integer, UserString>, UserString> {
+    @Override
+    public UserString apply(KV<Integer, UserString> input) {
+      return input.getValue();
+    }
+  }
+
+  private static class ConcatString extends BinaryCombineFn<UserString> {
+    @Override
+    public UserString apply(UserString left, UserString right) {
+      String retStr = left.strValue + right.strValue;
+      char[] chars = retStr.toCharArray();
+      Arrays.sort(chars);
+      return UserString.of(new String(chars));
+    }
+  }
+
+  private static class OutputNullString extends BinaryCombineFn<UserString> {
+    @Override
+    public UserString apply(UserString left, UserString right) {
+      return null;
+    }
+  }
+
+  private static class ConcatStringWithContext
+      extends KeyedCombineFnWithContext<String, UserString, UserString, UserString> {
+    private final PCollectionView<String> view;
+
+    private ConcatStringWithContext(PCollectionView<String> view) {
+      this.view = view;
+    }
+
+    @Override
+    public UserString createAccumulator(String key, CombineWithContext.Context c) {
+      return UserString.of(key + c.sideInput(view));
+    }
+
+    @Override
+    public UserString addInput(
+        String key, UserString accumulator, UserString input, CombineWithContext.Context c) {
+      assertThat(accumulator.strValue, Matchers.startsWith(key + c.sideInput(view)));
+      accumulator.strValue += input.strValue;
+      return accumulator;
+    }
+
+    @Override
+    public UserString mergeAccumulators(
+        String key, Iterable<UserString> accumulators, CombineWithContext.Context c) {
+      String keyPrefix = key + c.sideInput(view);
+      String all = keyPrefix;
+      for (UserString accumulator : accumulators) {
+        assertThat(accumulator.strValue, Matchers.startsWith(keyPrefix));
+        all += accumulator.strValue.substring(keyPrefix.length());
+        accumulator.strValue = "cleared in mergeAccumulators";
+      }
+      return UserString.of(all);
+    }
+
+    @Override
+    public UserString extractOutput(
+        String key, UserString accumulator, CombineWithContext.Context c) {
+      assertThat(accumulator.strValue, Matchers.startsWith(key + c.sideInput(view)));
+      char[] chars = accumulator.strValue.toCharArray();
+      Arrays.sort(chars);
+      return UserString.of(new String(chars));
+    }
+  }
+
+  private static class ExtractResultDoFn
+      extends DoFn<KV<String, CoCombineResult>, KV<String, KV<Integer, String>>>{
+
+    private final TupleTag<Integer> maxIntTag;
+    private final TupleTag<UserString> concatStringTag;
+
+    ExtractResultDoFn(TupleTag<Integer> maxIntTag, TupleTag<UserString> concatStringTag) {
+      this.maxIntTag = maxIntTag;
+      this.concatStringTag = concatStringTag;
+    }
+
+    @Override
+    public void processElement(ProcessContext c) throws Exception {
+      UserString userString = c.element().getValue().get(concatStringTag);
+      KV<Integer, String> value = KV.of(
+          c.element().getValue().get(maxIntTag),
+          userString == null ? null : userString.strValue);
+      c.output(KV.of(c.element().getKey(), value));
+    }
+  }
+}


Mime
View raw message