From: amitsela@apache.org
To: commits@beam.incubator.apache.org
Date: Wed, 21 Sep 2016 17:25:07 -0000
Subject: [2/4] incubator-beam git commit: [BEAM-610] Enable spark's checkpointing mechanism for driver-failure recovery in streaming.
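For orientation, the sketch below (not part of this commit) shows the Spark 1.x recovery pattern that the SparkRunnerStreamingContextFactory introduced further down plugs into: JavaStreamingContext.getOrCreate invokes the factory on a clean start, and after a driver failure rebuilds the context, DStream graph included, from the checkpoint directory instead. The master, app name, socket source, and checkpoint path here are purely illustrative.

    import org.apache.spark.streaming.Duration;
    import org.apache.spark.streaming.api.java.JavaStreamingContext;
    import org.apache.spark.streaming.api.java.JavaStreamingContextFactory;

    public class CheckpointRecoveryExample {
      public static void main(String[] args) throws Exception {
        // Illustrative checkpoint location; a durable filesystem (e.g. HDFS) is expected in production.
        final String checkpointDir = "file:///tmp/beam-spark-checkpoint";

        JavaStreamingContextFactory factory = new JavaStreamingContextFactory() {
          @Override
          public JavaStreamingContext create() {
            // Invoked only when no checkpoint exists: build the context and the DStream graph,
            // then register the checkpoint directory on the new context.
            JavaStreamingContext jssc = new JavaStreamingContext(
                "local[2]", "checkpoint-recovery-example", new Duration(1000));
            jssc.socketTextStream("localhost", 9999).print();
            jssc.checkpoint(checkpointDir);
            return jssc;
          }
        };

        // On restart after a driver failure the context is reconstructed from the checkpoint
        // rather than by calling the factory again; this is what provides driver-failure recovery.
        JavaStreamingContext jssc = JavaStreamingContext.getOrCreate(checkpointDir, factory);
        jssc.start();
        jssc.awaitTermination();
      }
    }
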
archived-at: Wed, 21 Sep 2016 17:25:23 -0000 http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0feb6499/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java index 8341c6d..1a0511f 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java @@ -19,39 +19,32 @@ package org.apache.beam.runners.spark.translation; +import static com.google.common.base.Preconditions.checkState; import static org.apache.beam.runners.spark.io.hadoop.ShardNameBuilder.getOutputDirectory; import static org.apache.beam.runners.spark.io.hadoop.ShardNameBuilder.getOutputFilePrefix; import static org.apache.beam.runners.spark.io.hadoop.ShardNameBuilder.getOutputFileTemplate; import static org.apache.beam.runners.spark.io.hadoop.ShardNameBuilder.replaceShardCount; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Lists; import com.google.common.collect.Maps; import java.io.IOException; -import java.io.Serializable; -import java.lang.reflect.Field; -import java.util.Arrays; import java.util.Collections; -import java.util.List; import java.util.Map; import org.apache.avro.mapred.AvroKey; import org.apache.avro.mapreduce.AvroJob; import org.apache.avro.mapreduce.AvroKeyInputFormat; import org.apache.beam.runners.core.AssignWindowsDoFn; -import org.apache.beam.runners.core.GroupAlsoByWindowsViaOutputBufferDoFn; import org.apache.beam.runners.core.GroupByKeyViaGroupByKeyOnly.GroupAlsoByWindow; import org.apache.beam.runners.core.GroupByKeyViaGroupByKeyOnly.GroupByKeyOnly; -import org.apache.beam.runners.core.SystemReduceFn; +import org.apache.beam.runners.spark.aggregators.AccumulatorSingleton; +import org.apache.beam.runners.spark.aggregators.NamedAggregators; import org.apache.beam.runners.spark.coders.CoderHelpers; import org.apache.beam.runners.spark.io.hadoop.HadoopIO; import org.apache.beam.runners.spark.io.hadoop.ShardNameTemplateHelper; import org.apache.beam.runners.spark.io.hadoop.TemplatedAvroKeyOutputFormat; import org.apache.beam.runners.spark.io.hadoop.TemplatedTextOutputFormat; import org.apache.beam.runners.spark.util.BroadcastHelper; -import org.apache.beam.runners.spark.util.ByteArray; import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.io.AvroIO; import org.apache.beam.sdk.io.TextIO; @@ -63,36 +56,30 @@ import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.View; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.transforms.windowing.GlobalWindows; import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.transforms.windowing.WindowFn; import org.apache.beam.sdk.util.WindowedValue; -import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder; -import org.apache.beam.sdk.util.WindowingStrategy; -import org.apache.beam.sdk.util.state.InMemoryStateInternals; -import 
org.apache.beam.sdk.util.state.StateInternals; -import org.apache.beam.sdk.util.state.StateInternalsFactory; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionList; import org.apache.beam.sdk.values.PCollectionTuple; -import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TupleTag; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.io.NullWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.mapreduce.Job; import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; +import org.apache.spark.Accumulator; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaRDDLike; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.function.PairFlatMapFunction; import org.apache.spark.api.java.function.PairFunction; + import scala.Tuple2; + /** * Supports translation between a Beam transform, and Spark's operations on RDDs. */ @@ -101,31 +88,6 @@ public final class TransformTranslator { private TransformTranslator() { } - /** - * Getter of the field. - */ - public static class FieldGetter { - private final Map fields; - - public FieldGetter(Class clazz) { - this.fields = Maps.newHashMap(); - for (Field f : clazz.getDeclaredFields()) { - f.setAccessible(true); - this.fields.put(f.getName(), f); - } - } - - public T get(String fieldname, Object value) { - try { - @SuppressWarnings("unchecked") - T fieldValue = (T) fields.get(fieldname).get(value); - return fieldValue; - } catch (IllegalAccessException e) { - throw new IllegalStateException(e); - } - } - } - private static TransformEvaluator> flattenPColl() { return new TransformEvaluator>() { @SuppressWarnings("unchecked") @@ -142,28 +104,18 @@ public final class TransformTranslator { }; } - private static TransformEvaluator> gbk() { + private static TransformEvaluator> gbko() { return new TransformEvaluator>() { @Override public void evaluate(GroupByKeyOnly transform, EvaluationContext context) { @SuppressWarnings("unchecked") - JavaRDDLike>, ?> inRDD = - (JavaRDDLike>, ?>) context.getInputRDD(transform); + JavaRDD>> inRDD = + (JavaRDD>>) context.getInputRDD(transform); + @SuppressWarnings("unchecked") - KvCoder coder = (KvCoder) context.getInput(transform).getCoder(); - Coder keyCoder = coder.getKeyCoder(); - Coder valueCoder = coder.getValueCoder(); - - // Use coders to convert objects in the PCollection to byte arrays, so they - // can be transferred over the network for the shuffle. 
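The removed comment above states the pattern this translator relies on throughout: elements are serialized with their Beam Coder so that Spark shuffles plain byte arrays rather than depending on Java/Kryo serialization of arbitrary user types. Below is a minimal standalone sketch of that encode/decode round trip, assuming the standard Beam CoderUtils helpers and a StringUtf8Coder chosen purely for illustration; CoderHelpers.toByteFunction and fromByteFunction in this file wrap the same idea as Spark Functions.

    import org.apache.beam.sdk.coders.Coder;
    import org.apache.beam.sdk.coders.StringUtf8Coder;
    import org.apache.beam.sdk.util.CoderUtils;

    public class CoderRoundTripExample {
      public static void main(String[] args) throws Exception {
        Coder<String> coder = StringUtf8Coder.of();
        // Encode on one side of the shuffle boundary...
        byte[] bytes = CoderUtils.encodeToByteArray(coder, "hello");
        // ...and decode on the other side, recovering an equal element.
        String decoded = CoderUtils.decodeFromByteArray(coder, bytes);
        System.out.println(decoded);  // prints "hello"
      }
    }
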
- JavaRDDLike>>, ?> outRDD = fromPair( - toPair(inRDD.map(WindowingHelpers.>unwindowFunction())) - .mapToPair(CoderHelpers.toByteFunction(keyCoder, valueCoder)) - .groupByKey() - .mapToPair(CoderHelpers.fromByteFunctionIterable(keyCoder, valueCoder))) - // empty windows are OK here, see GroupByKey#evaluateHelper in the SDK - .map(WindowingHelpers.>>windowFunction()); - context.setOutputRDD(transform, outRDD); + final KvCoder coder = (KvCoder) context.getInput(transform).getCoder(); + + context.setOutputRDD(transform, GroupCombineFunctions.groupByKeyOnly(inRDD, coder)); } }; } @@ -174,81 +126,52 @@ public final class TransformTranslator { @Override public void evaluate(GroupAlsoByWindow transform, EvaluationContext context) { @SuppressWarnings("unchecked") - JavaRDDLike>>>, ?> inRDD = - (JavaRDDLike>>>, ?>) + JavaRDD>>>> inRDD = + (JavaRDD>>>>) context.getInputRDD(transform); - Coder>>> inputCoder = - context.getInput(transform).getCoder(); - Coder keyCoder = transform.getKeyCoder(inputCoder); - Coder valueCoder = transform.getValueCoder(inputCoder); - @SuppressWarnings("unchecked") - KvCoder>> inputKvCoder = + final KvCoder>> inputKvCoder = (KvCoder>>) context.getInput(transform).getCoder(); - Coder>> inputValueCoder = inputKvCoder.getValueCoder(); - - IterableCoder> inputIterableValueCoder = - (IterableCoder>) inputValueCoder; - Coder> inputIterableElementCoder = inputIterableValueCoder.getElemCoder(); - WindowedValueCoder inputIterableWindowedValueCoder = - (WindowedValueCoder) inputIterableElementCoder; - Coder inputIterableElementValueCoder = inputIterableWindowedValueCoder.getValueCoder(); + final Accumulator accum = + AccumulatorSingleton.getInstance(context.getSparkContext()); - @SuppressWarnings("unchecked") - WindowingStrategy windowingStrategy = - (WindowingStrategy) transform.getWindowingStrategy(); - - OldDoFn>>, KV>> gabwDoFn = - new GroupAlsoByWindowsViaOutputBufferDoFn, W>( - windowingStrategy, - new InMemoryStateInternalsFactory(), - SystemReduceFn.buffering(inputIterableElementValueCoder)); - - // GroupAlsoByWindow current uses a dummy in-memory StateInternals - JavaRDDLike>>, ?> outRDD = - inRDD.mapPartitions( - new DoFnFunction>>, KV>>( - gabwDoFn, context.getRuntimeContext(), null)); - - context.setOutputRDD(transform, outRDD); + context.setOutputRDD(transform, GroupCombineFunctions.groupAlsoByWindow(inRDD, transform, + context.getRuntimeContext(), accum, inputKvCoder)); } }; } - private static final FieldGetter GROUPED_FG = new FieldGetter(Combine.GroupedValues.class); - private static TransformEvaluator> grouped() { return new TransformEvaluator>() { @Override public void evaluate(Combine.GroupedValues transform, EvaluationContext context) { - Combine.KeyedCombineFn keyed = GROUPED_FG.get("fn", transform); @SuppressWarnings("unchecked") JavaRDDLike>>, ?> inRDD = - (JavaRDDLike>>, ?>) context.getInputRDD(transform); - context.setOutputRDD(transform, - inRDD.map(new KVFunction<>(keyed))); + (JavaRDDLike>>, ?>) + context.getInputRDD(transform); + context.setOutputRDD(transform, inRDD.map( + new TranslationUtils.CombineGroupedValues<>(transform))); } }; } - private static final FieldGetter COMBINE_GLOBALLY_FG = new FieldGetter(Combine.Globally.class); - private static TransformEvaluator> combineGlobally() { return new TransformEvaluator>() { @Override public void evaluate(Combine.Globally transform, EvaluationContext context) { - final Combine.CombineFn globally = - COMBINE_GLOBALLY_FG.get("fn", transform); + @SuppressWarnings("unchecked") + JavaRDD> inRdd = + (JavaRDD>) 
context.getInputRDD(transform); @SuppressWarnings("unchecked") - JavaRDDLike, ?> inRdd = - (JavaRDDLike, ?>) context.getInputRDD(transform); + final Combine.CombineFn globally = + (Combine.CombineFn) transform.getFn(); final Coder iCoder = context.getInput(transform).getCoder(); final Coder aCoder; @@ -259,61 +182,26 @@ public final class TransformTranslator { throw new IllegalStateException("Could not determine coder for accumulator", e); } - // Use coders to convert objects in the PCollection to byte arrays, so they - // can be transferred over the network for the shuffle. - JavaRDD inRddBytes = inRdd - .map(WindowingHelpers.unwindowFunction()) - .map(CoderHelpers.toByteFunction(iCoder)); - - /*AccumT*/ byte[] acc = inRddBytes.aggregate( - CoderHelpers.toByteArray(globally.createAccumulator(), aCoder), - new Function2() { - @Override - public /*AccumT*/ byte[] call(/*AccumT*/ byte[] ab, /*InputT*/ byte[] ib) - throws Exception { - AccumT a = CoderHelpers.fromByteArray(ab, aCoder); - InputT i = CoderHelpers.fromByteArray(ib, iCoder); - return CoderHelpers.toByteArray(globally.addInput(a, i), aCoder); - } - }, - new Function2() { - @Override - public /*AccumT*/ byte[] call(/*AccumT*/ byte[] a1b, /*AccumT*/ byte[] a2b) - throws Exception { - AccumT a1 = CoderHelpers.fromByteArray(a1b, aCoder); - AccumT a2 = CoderHelpers.fromByteArray(a2b, aCoder); - // don't use Guava's ImmutableList.of as values may be null - List accumulators = Collections.unmodifiableList(Arrays.asList(a1, a2)); - AccumT merged = globally.mergeAccumulators(accumulators); - return CoderHelpers.toByteArray(merged, aCoder); - } - } - ); - OutputT output = globally.extractOutput(CoderHelpers.fromByteArray(acc, aCoder)); - - Coder coder = context.getOutput(transform).getCoder(); + final Coder oCoder = context.getOutput(transform).getCoder(); JavaRDD outRdd = context.getSparkContext().parallelize( // don't use Guava's ImmutableList.of as output may be null - CoderHelpers.toByteArrays(Collections.singleton(output), coder)); - context.setOutputRDD(transform, outRdd.map(CoderHelpers.fromByteFunction(coder)) + CoderHelpers.toByteArrays(Collections.singleton( + GroupCombineFunctions.combineGlobally(inRdd, globally, iCoder, aCoder)), oCoder)); + context.setOutputRDD(transform, outRdd.map(CoderHelpers.fromByteFunction(oCoder)) .map(WindowingHelpers.windowFunction())); } }; } - private static final FieldGetter COMBINE_PERKEY_FG = new FieldGetter(Combine.PerKey.class); - private static TransformEvaluator> combinePerKey() { return new TransformEvaluator>() { @Override - public void evaluate(Combine.PerKey - transform, EvaluationContext context) { - final Combine.KeyedCombineFn keyed = - COMBINE_PERKEY_FG.get("fn", transform); + public void evaluate(Combine.PerKey transform, + EvaluationContext context) { @SuppressWarnings("unchecked") - JavaRDDLike>, ?> inRdd = - (JavaRDDLike>, ?>) context.getInputRDD(transform); + final Combine.KeyedCombineFn keyed = + (Combine.KeyedCombineFn) transform.getFn(); @SuppressWarnings("unchecked") KvCoder inputCoder = (KvCoder) @@ -329,214 +217,66 @@ public final class TransformTranslator { } Coder> kviCoder = KvCoder.of(keyCoder, viCoder); Coder> kvaCoder = KvCoder.of(keyCoder, vaCoder); - - // We need to duplicate K as both the key of the JavaPairRDD as well as inside the value, - // since the functions passed to combineByKey don't receive the associated key of each - // value, and we need to map back into methods in Combine.KeyedCombineFn, which each - // require the key in addition to the InputT's and 
AccumT's being merged/accumulated. - // Once Spark provides a way to include keys in the arguments of combine/merge functions, - // we won't need to duplicate the keys anymore. - - // Key has to bw windowed in order to group by window as well - JavaPairRDD, WindowedValue>> inRddDuplicatedKeyPair = - inRdd.flatMapToPair( - new PairFlatMapFunction>, WindowedValue, - WindowedValue>>() { - @Override - public Iterable, - WindowedValue>>> - call(WindowedValue> kv) { - List, - WindowedValue>>> tuple2s = - Lists.newArrayListWithCapacity(kv.getWindows().size()); - for (BoundedWindow boundedWindow: kv.getWindows()) { - WindowedValue wk = WindowedValue.of(kv.getValue().getKey(), - boundedWindow.maxTimestamp(), boundedWindow, kv.getPane()); - tuple2s.add(new Tuple2<>(wk, kv)); - } - return tuple2s; - } - }); //-- windowed coders final WindowedValue.FullWindowedValueCoder wkCoder = - WindowedValue.FullWindowedValueCoder.of(keyCoder, + WindowedValue.FullWindowedValueCoder.of(keyCoder, context.getInput(transform).getWindowingStrategy().getWindowFn().windowCoder()); final WindowedValue.FullWindowedValueCoder> wkviCoder = - WindowedValue.FullWindowedValueCoder.of(kviCoder, + WindowedValue.FullWindowedValueCoder.of(kviCoder, context.getInput(transform).getWindowingStrategy().getWindowFn().windowCoder()); final WindowedValue.FullWindowedValueCoder> wkvaCoder = - WindowedValue.FullWindowedValueCoder.of(kvaCoder, + WindowedValue.FullWindowedValueCoder.of(kvaCoder, context.getInput(transform).getWindowingStrategy().getWindowFn().windowCoder()); - // Use coders to convert objects in the PCollection to byte arrays, so they - // can be transferred over the network for the shuffle. - JavaPairRDD inRddDuplicatedKeyPairBytes = inRddDuplicatedKeyPair - .mapToPair(CoderHelpers.toByteFunction(wkCoder, wkviCoder)); - - // The output of combineByKey will be "AccumT" (accumulator) - // types rather than "OutputT" (final output types) since Combine.CombineFn - // only provides ways to merge VAs, and no way to merge VOs. 
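The removed comments above hinge on a property of Spark's combineByKey: none of its three functions (create combiner, merge value, merge combiners) receives the key, while Combine.KeyedCombineFn needs the key on every call, hence the key duplication. A small standalone sketch of the combineByKey signature in question, summing integers per key; the class and variable names are illustrative, not from this commit.

    import java.util.Arrays;
    import org.apache.spark.api.java.JavaPairRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.api.java.function.Function;
    import org.apache.spark.api.java.function.Function2;
    import scala.Tuple2;

    public class CombineByKeyExample {
      public static void main(String[] args) {
        JavaSparkContext jsc = new JavaSparkContext("local[2]", "combine-by-key-example");
        JavaPairRDD<String, Integer> pairs = jsc.parallelizePairs(Arrays.asList(
            new Tuple2<>("a", 1), new Tuple2<>("a", 2), new Tuple2<>("b", 3)));
        // None of the three functions receives the key, which is why the translator above
        // has to duplicate the key inside the value before calling combineByKey.
        JavaPairRDD<String, Integer> sums = pairs.combineByKey(
            new Function<Integer, Integer>() {           // create combiner from first value
              @Override public Integer call(Integer v) { return v; }
            },
            new Function2<Integer, Integer, Integer>() { // merge a value into an accumulator
              @Override public Integer call(Integer acc, Integer v) { return acc + v; }
            },
            new Function2<Integer, Integer, Integer>() { // merge two accumulators
              @Override public Integer call(Integer a1, Integer a2) { return a1 + a2; }
            });
        System.out.println(sums.collectAsMap());  // {a=3, b=3}
        jsc.stop();
      }
    }
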
- JavaPairRDD*/ byte[]> accumulatedBytes = - inRddDuplicatedKeyPairBytes.combineByKey( - new Function*/ byte[], /*KV*/ byte[]>() { - @Override - public /*KV*/ byte[] call(/*KV*/ byte[] input) { - WindowedValue> wkvi = - CoderHelpers.fromByteArray(input, wkviCoder); - AccumT va = keyed.createAccumulator(wkvi.getValue().getKey()); - va = keyed.addInput(wkvi.getValue().getKey(), va, wkvi.getValue().getValue()); - WindowedValue> wkva = - WindowedValue.of(KV.of(wkvi.getValue().getKey(), va), wkvi.getTimestamp(), - wkvi.getWindows(), wkvi.getPane()); - return CoderHelpers.toByteArray(wkva, wkvaCoder); - } - }, - new Function2*/ byte[], - /*KV*/ byte[], - /*KV*/ byte[]>() { - @Override - public /*KV*/ byte[] call(/*KV*/ byte[] acc, - /*KV*/ byte[] input) { - WindowedValue> wkva = - CoderHelpers.fromByteArray(acc, wkvaCoder); - WindowedValue> wkvi = - CoderHelpers.fromByteArray(input, wkviCoder); - AccumT va = - keyed.addInput(wkva.getValue().getKey(), wkva.getValue().getValue(), - wkvi.getValue().getValue()); - wkva = WindowedValue.of(KV.of(wkva.getValue().getKey(), va), wkva.getTimestamp(), - wkva.getWindows(), wkva.getPane()); - return CoderHelpers.toByteArray(wkva, wkvaCoder); - } - }, - new Function2*/ byte[], - /*KV*/ byte[], - /*KV*/ byte[]>() { - @Override - public /*KV*/ byte[] call(/*KV*/ byte[] acc1, - /*KV*/ byte[] acc2) { - WindowedValue> wkva1 = - CoderHelpers.fromByteArray(acc1, wkvaCoder); - WindowedValue> wkva2 = - CoderHelpers.fromByteArray(acc2, wkvaCoder); - AccumT va = keyed.mergeAccumulators(wkva1.getValue().getKey(), - // don't use Guava's ImmutableList.of as values may be null - Collections.unmodifiableList(Arrays.asList(wkva1.getValue().getValue(), - wkva2.getValue().getValue()))); - WindowedValue> wkva = - WindowedValue.of(KV.of(wkva1.getValue().getKey(), - va), wkva1.getTimestamp(), wkva1.getWindows(), wkva1.getPane()); - return CoderHelpers.toByteArray(wkva, wkvaCoder); - } - }); - - JavaPairRDD, WindowedValue> extracted = accumulatedBytes - .mapToPair(CoderHelpers.fromByteFunction(wkCoder, wkvaCoder)) - .mapValues( - new Function>, WindowedValue>() { - @Override - public WindowedValue call(WindowedValue> acc) { - return WindowedValue.of(keyed.extractOutput(acc.getValue().getKey(), - acc.getValue().getValue()), acc.getTimestamp(), - acc.getWindows(), acc.getPane()); - } - }); + @SuppressWarnings("unchecked") + JavaRDD>> inRdd = + (JavaRDD>>) context.getInputRDD(transform); - context.setOutputRDD(transform, - fromPair(extracted) - .map(new Function, WindowedValue>, - WindowedValue>>() { - @Override - public WindowedValue> call(KV, - WindowedValue> kwvo) - throws Exception { - WindowedValue wvo = kwvo.getValue(); - KV kvo = KV.of(kwvo.getKey().getValue(), wvo.getValue()); - return WindowedValue.of(kvo, wvo.getTimestamp(), wvo.getWindows(), wvo.getPane()); - } - })); + context.setOutputRDD(transform, GroupCombineFunctions.combinePerKey(inRdd, keyed, wkCoder, + wkviCoder, wkvaCoder)); } }; } - private static final class KVFunction - implements Function>>, - WindowedValue>> { - private final Combine.KeyedCombineFn keyed; - - KVFunction(Combine.KeyedCombineFn keyed) { - this.keyed = keyed; - } - - @Override - public WindowedValue> call(WindowedValue>> windowedKv) - throws Exception { - KV> kv = windowedKv.getValue(); - return WindowedValue.of(KV.of(kv.getKey(), keyed.apply(kv.getKey(), kv.getValue())), - windowedKv.getTimestamp(), windowedKv.getWindows(), windowedKv.getPane()); - } - } - - private static JavaPairRDD toPair(JavaRDDLike, ?> rdd) { - return rdd.mapToPair(new 
PairFunction, K, V>() { - @Override - public Tuple2 call(KV kv) { - return new Tuple2<>(kv.getKey(), kv.getValue()); - } - }); - } - - private static JavaRDDLike, ?> fromPair(JavaPairRDD rdd) { - return rdd.map(new Function, KV>() { - @Override - public KV call(Tuple2 t2) { - return KV.of(t2._1(), t2._2()); - } - }); - } - private static TransformEvaluator> parDo() { return new TransformEvaluator>() { @Override public void evaluate(ParDo.Bound transform, EvaluationContext context) { - DoFnFunction dofn = - new DoFnFunction<>(transform.getFn(), - context.getRuntimeContext(), - getSideInputs(transform.getSideInputs(), context)); @SuppressWarnings("unchecked") JavaRDDLike, ?> inRDD = (JavaRDDLike, ?>) context.getInputRDD(transform); - context.setOutputRDD(transform, inRDD.mapPartitions(dofn)); + Accumulator accum = + AccumulatorSingleton.getInstance(context.getSparkContext()); + Map, BroadcastHelper> sideInputs = + TranslationUtils.getSideInputs(transform.getSideInputs(), context); + context.setOutputRDD(transform, + inRDD.mapPartitions(new DoFnFunction<>(accum, transform.getFn(), + context.getRuntimeContext(), sideInputs))); } }; } - private static final FieldGetter MULTIDO_FG = new FieldGetter(ParDo.BoundMulti.class); - - private static TransformEvaluator> multiDo() { + private static TransformEvaluator> + multiDo() { return new TransformEvaluator>() { @Override public void evaluate(ParDo.BoundMulti transform, EvaluationContext context) { - TupleTag mainOutputTag = MULTIDO_FG.get("mainOutputTag", transform); - MultiDoFnFunction multifn = new MultiDoFnFunction<>( - transform.getFn(), - context.getRuntimeContext(), - mainOutputTag, - getSideInputs(transform.getSideInputs(), context)); - @SuppressWarnings("unchecked") JavaRDDLike, ?> inRDD = (JavaRDDLike, ?>) context.getInputRDD(transform); + Accumulator accum = + AccumulatorSingleton.getInstance(context.getSparkContext()); JavaPairRDD, WindowedValue> all = inRDD - .mapPartitionsToPair(multifn) - .cache(); - + .mapPartitionsToPair( + new MultiDoFnFunction<>(accum, transform.getFn(), context.getRuntimeContext(), + transform.getMainOutputTag(), TranslationUtils.getSideInputs( + transform.getSideInputs(), context))) + .cache(); PCollectionTuple pct = context.getOutput(transform); for (Map.Entry, PCollection> e : pct.getAll().entrySet()) { @SuppressWarnings("unchecked") JavaPairRDD, WindowedValue> filtered = - all.filter(new TupleTagFilter(e.getKey())); + all.filter(new TranslationUtils.TupleTagFilter(e.getKey())); @SuppressWarnings("unchecked") // Object is the best we can do since different outputs can have different tags JavaRDD> values = @@ -753,22 +493,17 @@ public final class TransformTranslator { JavaRDDLike, ?> inRDD = (JavaRDDLike, ?>) context.getInputRDD(transform); - @SuppressWarnings("unchecked") - WindowFn windowFn = (WindowFn) transform.getWindowFn(); - - // Avoid running assign windows if both source and destination are global window - // or if the user has not specified the WindowFn (meaning they are just messing - // with triggering or allowed lateness) - if (windowFn == null - || (context.getInput(transform).getWindowingStrategy().getWindowFn() - instanceof GlobalWindows - && windowFn instanceof GlobalWindows)) { + if (TranslationUtils.skipAssignWindows(transform, context)) { context.setOutputRDD(transform, inRDD); } else { + @SuppressWarnings("unchecked") + WindowFn windowFn = (WindowFn) transform.getWindowFn(); OldDoFn addWindowsDoFn = new AssignWindowsDoFn<>(windowFn); - DoFnFunction dofn = - new DoFnFunction<>(addWindowsDoFn, 
context.getRuntimeContext(), null); - context.setOutputRDD(transform, inRDD.mapPartitions(dofn)); + Accumulator accum = + AccumulatorSingleton.getInstance(context.getSparkContext()); + context.setOutputRDD(transform, + inRDD.mapPartitions(new DoFnFunction<>(accum, addWindowsDoFn, + context.getRuntimeContext(), null))); } } }; @@ -822,42 +557,6 @@ public final class TransformTranslator { }; } - private static final class TupleTagFilter - implements Function, WindowedValue>, Boolean> { - - private final TupleTag tag; - - private TupleTagFilter(TupleTag tag) { - this.tag = tag; - } - - @Override - public Boolean call(Tuple2, WindowedValue> input) { - return tag.equals(input._1()); - } - } - - private static Map, BroadcastHelper> getSideInputs( - List> views, - EvaluationContext context) { - if (views == null) { - return ImmutableMap.of(); - } else { - Map, BroadcastHelper> sideInputs = Maps.newHashMap(); - for (PCollectionView view : views) { - Iterable> collectionView = context.getPCollectionView(view); - Coder>> coderInternal = view.getCoderInternal(); - @SuppressWarnings("unchecked") - BroadcastHelper helper = - BroadcastHelper.create((Iterable>) collectionView, coderInternal); - //broadcast side inputs - helper.broadcast(context.getSparkContext()); - sideInputs.put(view.getTagInternal(), helper); - } - return sideInputs; - } - } - private static final Map, TransformEvaluator> EVALUATORS = Maps .newHashMap(); @@ -870,7 +569,7 @@ public final class TransformTranslator { EVALUATORS.put(HadoopIO.Write.Bound.class, writeHadoop()); EVALUATORS.put(ParDo.Bound.class, parDo()); EVALUATORS.put(ParDo.BoundMulti.class, multiDo()); - EVALUATORS.put(GroupByKeyOnly.class, gbk()); + EVALUATORS.put(GroupByKeyOnly.class, gbko()); EVALUATORS.put(GroupAlsoByWindow.class, gabw()); EVALUATORS.put(Combine.GroupedValues.class, grouped()); EVALUATORS.put(Combine.Globally.class, combineGlobally()); @@ -883,17 +582,6 @@ public final class TransformTranslator { EVALUATORS.put(Window.Bound.class, window()); } - public static > TransformEvaluator - getTransformEvaluator(Class clazz) { - @SuppressWarnings("unchecked") - TransformEvaluator transform = - (TransformEvaluator) EVALUATORS.get(clazz); - if (transform == null) { - throw new IllegalStateException("No TransformEvaluator registered for " + clazz); - } - return transform; - } - /** * Translator matches Beam transformation with the appropriate evaluator. 
*/ @@ -905,17 +593,20 @@ public final class TransformTranslator { } @Override - public > TransformEvaluator translate( - Class clazz) { - return getTransformEvaluator(clazz); + public > TransformEvaluator + translateBounded (Class clazz) { + @SuppressWarnings("unchecked") TransformEvaluator transformEvaluator = + (TransformEvaluator) EVALUATORS.get(clazz); + checkState(transformEvaluator != null, + "No TransformEvaluator registered for BOUNDED transform %s", clazz); + return transformEvaluator; } - } - private static class InMemoryStateInternalsFactory implements StateInternalsFactory, - Serializable { @Override - public StateInternals stateInternalsForKey(K key) { - return InMemoryStateInternals.forKey(key); + public > TransformEvaluator + translateUnbounded(Class clazz) { + throw new IllegalStateException("TransformTranslator used in a batch pipeline only " + + "supports BOUNDED transforms."); } } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0feb6499/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java new file mode 100644 index 0000000..9b156fe --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.beam.runners.spark.translation; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Maps; +import java.io.Serializable; +import java.util.List; +import java.util.Map; +import org.apache.beam.runners.spark.util.BroadcastHelper; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindows; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.transforms.windowing.WindowFn; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.util.state.InMemoryStateInternals; +import org.apache.beam.sdk.util.state.StateInternals; +import org.apache.beam.sdk.util.state.StateInternalsFactory; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; + +import scala.Tuple2; + +/** + * A set of utilities to help translating Beam transformations into Spark transformations. + */ +public final class TranslationUtils { + + private TranslationUtils() { + } + + /** + * In-memory state internals factory. + * + * @param State key type. + */ + static class InMemoryStateInternalsFactory implements StateInternalsFactory, + Serializable { + @Override + public StateInternals stateInternalsForKey(K key) { + return InMemoryStateInternals.forKey(key); + } + } + + /** + * A {@link Combine.GroupedValues} function applied to grouped KVs. + * + * @param Grouped key type. + * @param Grouped values type. + * @param Output type. + */ + public static class CombineGroupedValues implements + Function>>, WindowedValue>> { + private final Combine.KeyedCombineFn keyed; + + public CombineGroupedValues(Combine.GroupedValues transform) { + //noinspection unchecked + keyed = (Combine.KeyedCombineFn) transform.getFn(); + } + + @Override + public WindowedValue> call(WindowedValue>> windowedKv) + throws Exception { + KV> kv = windowedKv.getValue(); + return WindowedValue.of(KV.of(kv.getKey(), keyed.apply(kv.getKey(), kv.getValue())), + windowedKv.getTimestamp(), windowedKv.getWindows(), windowedKv.getPane()); + } + } + + /** + * Checks if the window transformation should be applied or skipped. + * + *

+ * Avoid running assign windows if both the source and the destination use the global window,
+ * or if the user has not specified a WindowFn (meaning they are only adjusting
+ * triggering or allowed lateness).
+ *

+ * + * @param transform The {@link Window.Bound} transformation. + * @param context The {@link EvaluationContext}. + * @param PCollection type. + * @param {@link BoundedWindow} type. + * @return if to apply the transformation. + */ + public static boolean + skipAssignWindows(Window.Bound transform, EvaluationContext context) { + @SuppressWarnings("unchecked") + WindowFn windowFn = (WindowFn) transform.getWindowFn(); + return windowFn == null + || (context.getInput(transform).getWindowingStrategy().getWindowFn() + instanceof GlobalWindows + && windowFn instanceof GlobalWindows); + } + + /** Transform a pair stream into a value stream. */ + public static JavaDStream dStreamValues(JavaPairDStream pairDStream) { + return pairDStream.map(new Function, T2>() { + @Override + public T2 call(Tuple2 v1) throws Exception { + return v1._2(); + } + }); + } + + /** {@link KV} to pair function. */ + static PairFunction, K, V> toPairFunction() { + return new PairFunction, K, V>() { + @Override + public Tuple2 call(KV kv) { + return new Tuple2<>(kv.getKey(), kv.getValue()); + } + }; + } + + /** A pair to {@link KV} function . */ + static Function, KV> fromPairFunction() { + return new Function, KV>() { + @Override + public KV call(Tuple2 t2) { + return KV.of(t2._1(), t2._2()); + } + }; + } + + /** + * A utility class to filter {@link TupleTag}s. + * + * @param TupleTag type. + */ + public static final class TupleTagFilter + implements Function, WindowedValue>, Boolean> { + + private final TupleTag tag; + + public TupleTagFilter(TupleTag tag) { + this.tag = tag; + } + + @Override + public Boolean call(Tuple2, WindowedValue> input) { + return tag.equals(input._1()); + } + } + + /*** + * Create SideInputs as Broadcast variables. + * + * @param views The {@link PCollectionView}s. + * @param context The {@link EvaluationContext}. + * @return a map of tagged {@link BroadcastHelper}s. + */ + public static Map, BroadcastHelper> getSideInputs(List> views, + EvaluationContext context) { + if (views == null) { + return ImmutableMap.of(); + } else { + Map, BroadcastHelper> sideInputs = Maps.newHashMap(); + for (PCollectionView view : views) { + Iterable> collectionView = context.getPCollectionView(view); + Coder>> coderInternal = view.getCoderInternal(); + @SuppressWarnings("unchecked") + BroadcastHelper helper = + BroadcastHelper.create((Iterable>) collectionView, coderInternal); + //broadcast side inputs + helper.broadcast(context.getSparkContext()); + sideInputs.put(view.getTagInternal(), helper); + } + return sideInputs; + } + } + +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0feb6499/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/SparkRunnerStreamingContextFactory.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/SparkRunnerStreamingContextFactory.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/SparkRunnerStreamingContextFactory.java new file mode 100644 index 0000000..b7a407c --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/SparkRunnerStreamingContextFactory.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.spark.translation.streaming; + +import com.google.common.base.Predicates; +import com.google.common.collect.Iterables; +import java.net.MalformedURLException; +import java.net.URL; +import java.util.Arrays; +import org.apache.beam.runners.spark.SparkPipelineOptions; +import org.apache.beam.runners.spark.SparkRunner; +import org.apache.beam.runners.spark.translation.SparkContextFactory; +import org.apache.beam.runners.spark.translation.SparkPipelineTranslator; +import org.apache.beam.runners.spark.translation.TransformTranslator; +import org.apache.beam.sdk.Pipeline; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.api.java.JavaStreamingContext; +import org.apache.spark.streaming.api.java.JavaStreamingContextFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + + +/** + * A {@link JavaStreamingContext} factory for resilience. + * @see how-to-configure-checkpointing + */ +public class SparkRunnerStreamingContextFactory implements JavaStreamingContextFactory { + private static final Logger LOG = + LoggerFactory.getLogger(SparkRunnerStreamingContextFactory.class); + private static final Iterable KNOWN_RELIABLE_FS = Arrays.asList("hdfs", "s3", "gs"); + + private final Pipeline pipeline; + private final SparkPipelineOptions options; + + public SparkRunnerStreamingContextFactory(Pipeline pipeline, SparkPipelineOptions options) { + this.pipeline = pipeline; + this.options = options; + } + + private StreamingEvaluationContext ctxt; + + @Override + public JavaStreamingContext create() { + LOG.info("Creating a new Spark Streaming Context"); + + SparkPipelineTranslator translator = new StreamingTransformTranslator.Translator( + new TransformTranslator.Translator()); + Duration batchDuration = new Duration(options.getBatchIntervalMillis()); + LOG.info("Setting Spark streaming batchDuration to {} msec", batchDuration.milliseconds()); + + JavaSparkContext jsc = SparkContextFactory.getSparkContext(options); + JavaStreamingContext jssc = new JavaStreamingContext(jsc, batchDuration); + ctxt = new StreamingEvaluationContext(jsc, pipeline, jssc, + options.getTimeout()); + pipeline.traverseTopologically(new SparkRunner.Evaluator(translator, ctxt)); + ctxt.computeOutputs(); + + // set checkpoint dir. + String checkpointDir = options.getCheckpointDir(); + LOG.info("Checkpoint dir set to: {}", checkpointDir); + try { + // validate checkpoint dir and warn if not of a known durable filesystem. 
+ URL checkpointDirUrl = new URL(checkpointDir); + if (!Iterables.any(KNOWN_RELIABLE_FS, Predicates.equalTo(checkpointDirUrl.getProtocol()))) { + LOG.warn("Checkpoint dir URL {} does not match a reliable filesystem, in case of failures " + + "this job may not recover properly or even at all.", checkpointDirUrl); + } + } catch (MalformedURLException e) { + throw new RuntimeException("Failed to form checkpoint dir URL. CheckpointDir should be in " + + "the form of hdfs:///path/to/dir or other reliable fs protocol, " + + "or file:///path/to/dir for local mode.", e); + } + jssc.checkpoint(checkpointDir); + + return jssc; + } + + public StreamingEvaluationContext getCtxt() { + return ctxt; + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0feb6499/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingEvaluationContext.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingEvaluationContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingEvaluationContext.java index 2e4da44..5a43c55 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingEvaluationContext.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingEvaluationContext.java @@ -18,14 +18,18 @@ package org.apache.beam.runners.spark.translation.streaming; +import com.google.common.collect.Iterables; + import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.Map; import java.util.Queue; import java.util.Set; import java.util.concurrent.LinkedBlockingQueue; +import org.apache.beam.runners.spark.coders.CoderHelpers; import org.apache.beam.runners.spark.translation.EvaluationContext; import org.apache.beam.runners.spark.translation.SparkRuntimeContext; +import org.apache.beam.runners.spark.translation.WindowingHelpers; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.transforms.AppliedPTransform; @@ -82,11 +86,17 @@ public class StreamingEvaluationContext extends EvaluationContext { @SuppressWarnings("unchecked") JavaDStream> getDStream() { if (dStream == null) { - // create the DStream from values + WindowedValue.ValueOnlyWindowedValueCoder windowCoder = + WindowedValue.getValueOnlyCoder(coder); + // create the DStream from queue Queue>> rddQueue = new LinkedBlockingQueue<>(); for (Iterable v : values) { - setOutputRDDFromValues(currentTransform.getTransform(), v, coder); - rddQueue.offer((JavaRDD>) getOutputRDD(currentTransform.getTransform())); + Iterable> windowedValues = + Iterables.transform(v, WindowingHelpers.windowValueFunction()); + JavaRDD> rdd = getSparkContext().parallelize( + CoderHelpers.toByteArrays(windowedValues, windowCoder)).map( + CoderHelpers.fromByteFunction(windowCoder)); + rddQueue.offer(rdd); } // create dstream from queue, one at a time, no defaults // mainly for unit test so no reason to have this configurable @@ -102,7 +112,10 @@ public class StreamingEvaluationContext extends EvaluationContext { } void setStream(PTransform transform, JavaDStream> dStream) { - PValue pvalue = (PValue) getOutput(transform); + setStream((PValue) getOutput(transform), dStream); + } + + void setStream(PValue pvalue, JavaDStream> dStream) { DStreamHolder dStreamHolder = new DStreamHolder<>(dStream); pstreams.put(pvalue, dStreamHolder); 
leafStreams.add(dStreamHolder); @@ -110,6 +123,10 @@ public class StreamingEvaluationContext extends EvaluationContext { boolean hasStream(PTransform transform) { PValue pvalue = (PValue) getInput(transform); + return hasStream(pvalue); + } + + boolean hasStream(PValue pvalue) { return pstreams.containsKey(pvalue); } @@ -141,19 +158,23 @@ public class StreamingEvaluationContext extends EvaluationContext { @Override public void computeOutputs() { + super.computeOutputs(); // in case the pipeline contains bounded branches as well. for (DStreamHolder streamHolder : leafStreams) { computeOutput(streamHolder); - } + } // force a DStream action } private static void computeOutput(DStreamHolder streamHolder) { - streamHolder.getDStream().foreachRDD(new VoidFunction>>() { + JavaDStream> dStream = streamHolder.getDStream(); + // cache in DStream level not RDD + // because there could be a difference in StorageLevel if the DStream is windowed. + dStream.dstream().cache(); + dStream.foreachRDD(new VoidFunction>>() { @Override public void call(JavaRDD> rdd) throws Exception { - rdd.rdd().cache(); rdd.count(); } - }); // force a DStream action + }); } @Override @@ -163,8 +184,9 @@ public class StreamingEvaluationContext extends EvaluationContext { } else { jssc.awaitTermination(); } - //TODO: stop gracefully ? - jssc.stop(false, false); + // stop streaming context gracefully, so checkpointing (and other computations) get to + // finish before shutdown. + jssc.stop(false, true); state = State.DONE; super.close(); } @@ -197,7 +219,7 @@ public class StreamingEvaluationContext extends EvaluationContext { } @Override - protected void setCurrentTransform(AppliedPTransform transform) { + public void setCurrentTransform(AppliedPTransform transform) { super.setCurrentTransform(transform); } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0feb6499/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java index c55be3d..64ddc57 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java @@ -17,53 +17,68 @@ */ package org.apache.beam.runners.spark.translation.streaming; -import com.google.common.collect.Lists; +import static com.google.common.base.Preconditions.checkState; + import com.google.common.collect.Maps; -import com.google.common.collect.Sets; -import com.google.common.reflect.TypeToken; -import java.lang.reflect.ParameterizedType; -import java.lang.reflect.Type; +import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; import kafka.serializer.Decoder; import org.apache.beam.runners.core.AssignWindowsDoFn; +import org.apache.beam.runners.core.GroupByKeyViaGroupByKeyOnly; +import org.apache.beam.runners.core.GroupByKeyViaGroupByKeyOnly.GroupAlsoByWindow; +import org.apache.beam.runners.core.GroupByKeyViaGroupByKeyOnly.GroupByKeyOnly; +import org.apache.beam.runners.spark.aggregators.AccumulatorSingleton; +import org.apache.beam.runners.spark.aggregators.NamedAggregators; +import 
org.apache.beam.runners.spark.coders.CoderHelpers; import org.apache.beam.runners.spark.io.ConsoleIO; import org.apache.beam.runners.spark.io.CreateStream; import org.apache.beam.runners.spark.io.KafkaIO; -import org.apache.beam.runners.spark.io.hadoop.HadoopIO; import org.apache.beam.runners.spark.translation.DoFnFunction; import org.apache.beam.runners.spark.translation.EvaluationContext; +import org.apache.beam.runners.spark.translation.GroupCombineFunctions; +import org.apache.beam.runners.spark.translation.MultiDoFnFunction; import org.apache.beam.runners.spark.translation.SparkPipelineTranslator; +import org.apache.beam.runners.spark.translation.SparkRuntimeContext; import org.apache.beam.runners.spark.translation.TransformEvaluator; +import org.apache.beam.runners.spark.translation.TranslationUtils; import org.apache.beam.runners.spark.translation.WindowingHelpers; +import org.apache.beam.runners.spark.util.BroadcastHelper; +import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.io.AvroIO; -import org.apache.beam.sdk.io.TextIO; -import org.apache.beam.sdk.transforms.AppliedPTransform; -import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.Flatten; import org.apache.beam.sdk.transforms.OldDoFn; import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.FixedWindows; import org.apache.beam.sdk.transforms.windowing.SlidingWindows; import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.transforms.windowing.WindowFn; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionList; -import org.apache.beam.sdk.values.PDone; +import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.spark.Accumulator; +import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.VoidFunction; import org.apache.spark.streaming.Duration; import org.apache.spark.streaming.Durations; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaDStreamLike; +import org.apache.spark.streaming.api.java.JavaPairDStream; import org.apache.spark.streaming.api.java.JavaPairInputDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; import org.apache.spark.streaming.kafka.KafkaUtils; + import scala.Tuple2; @@ -114,19 +129,6 @@ public final class StreamingTransformTranslator { }; } - private static TransformEvaluator> create() { - return new TransformEvaluator>() { - @SuppressWarnings("unchecked") - @Override - public void evaluate(Create.Values transform, EvaluationContext context) { - StreamingEvaluationContext sec = (StreamingEvaluationContext) context; - Iterable elems = transform.getElements(); - Coder coder = sec.getOutput(transform).getCoder(); - sec.setDStreamFromQueue(transform, Collections.singletonList(elems), coder); - } - }; - } - private static TransformEvaluator> createFromQueue() { return new TransformEvaluator>() { @Override @@ 
-146,173 +148,325 @@ public final class StreamingTransformTranslator { public void evaluate(Flatten.FlattenPCollectionList transform, EvaluationContext context) { StreamingEvaluationContext sec = (StreamingEvaluationContext) context; PCollectionList pcs = sec.getInput(transform); - JavaDStream> first = - (JavaDStream>) sec.getStream(pcs.get(0)); - List>> rest = Lists.newArrayListWithCapacity(pcs.size() - 1); - for (int i = 1; i < pcs.size(); i++) { - rest.add((JavaDStream>) sec.getStream(pcs.get(i))); + // since this is a streaming pipeline, at least one of the PCollections to "flatten" are + // unbounded, meaning it represents a DStream. + // So we could end up with an unbounded unified DStream. + final List>> rdds = new ArrayList<>(); + final List>> dStreams = new ArrayList<>(); + for (PCollection pcol: pcs.getAll()) { + if (sec.hasStream(pcol)) { + dStreams.add((JavaDStream>) sec.getStream(pcol)); + } else { + rdds.add((JavaRDD>) context.getRDD(pcol)); + } + } + // start by unifying streams into a single stream. + JavaDStream> unifiedStreams = + sec.getStreamingContext().union(dStreams.remove(0), dStreams); + // now unify in RDDs. + if (rdds.size() > 0) { + JavaDStream> joined = unifiedStreams.transform( + new Function>, JavaRDD>>() { + @Override + public JavaRDD> call(JavaRDD> streamRdd) + throws Exception { + return new JavaSparkContext(streamRdd.context()).union(streamRdd, rdds); + } + }); + sec.setStream(transform, joined); + } else { + sec.setStream(transform, unifiedStreams); } - JavaDStream> dstream = sec.getStreamingContext().union(first, rest); - sec.setStream(transform, dstream); } }; } - private static > TransformEvaluator rddTransform( - final SparkPipelineTranslator rddTranslator) { - return new TransformEvaluator() { - @SuppressWarnings("unchecked") + private static TransformEvaluator> window() { + return new TransformEvaluator>() { @Override - public void evaluate(TransformT transform, EvaluationContext context) { - TransformEvaluator rddEvaluator = - rddTranslator.translate((Class) transform.getClass()); - + public void evaluate(Window.Bound transform, EvaluationContext context) { StreamingEvaluationContext sec = (StreamingEvaluationContext) context; - if (sec.hasStream(transform)) { - JavaDStreamLike, ?, JavaRDD>> dStream = - (JavaDStreamLike, ?, JavaRDD>>) - sec.getStream(transform); - - sec.setStream(transform, dStream - .transform(new RDDTransform<>(sec, rddEvaluator, transform))); + @SuppressWarnings("unchecked") + WindowFn windowFn = (WindowFn) transform.getWindowFn(); + @SuppressWarnings("unchecked") + JavaDStream> dStream = + (JavaDStream>) sec.getStream(transform); + if (windowFn instanceof FixedWindows) { + Duration windowDuration = Durations.milliseconds(((FixedWindows) windowFn).getSize() + .getMillis()); + sec.setStream(transform, dStream.window(windowDuration)); + } else if (windowFn instanceof SlidingWindows) { + Duration windowDuration = Durations.milliseconds(((SlidingWindows) windowFn).getSize() + .getMillis()); + Duration slideDuration = Durations.milliseconds(((SlidingWindows) windowFn).getPeriod() + .getMillis()); + sec.setStream(transform, dStream.window(windowDuration, slideDuration)); + } + //--- then we apply windowing to the elements + @SuppressWarnings("unchecked") + JavaDStream> dStream2 = + (JavaDStream>) sec.getStream(transform); + if (TranslationUtils.skipAssignWindows(transform, context)) { + sec.setStream(transform, dStream2); } else { - // if the transformation requires direct access to RDD (not in stream) - // this is used for "fake" 
transformations like with PAssert - rddEvaluator.evaluate(transform, context); + final OldDoFn addWindowsDoFn = new AssignWindowsDoFn<>(windowFn); + final SparkRuntimeContext runtimeContext = sec.getRuntimeContext(); + JavaDStream> outStream = dStream2.transform( + new Function>, JavaRDD>>() { + @Override + public JavaRDD> call(JavaRDD> rdd) throws Exception { + final Accumulator accum = + AccumulatorSingleton.getInstance(new JavaSparkContext(rdd.context())); + return rdd.mapPartitions( + new DoFnFunction<>(accum, addWindowsDoFn, runtimeContext, null)); + } + }); + sec.setStream(transform, outStream); } } }; } - /** - * RDD transform function If the transformation function doesn't have an input, create a fake one - * as an empty RDD. - * - * @param PTransform type - */ - private static final class RDDTransform> - implements Function>, JavaRDD>> { - - private final StreamingEvaluationContext context; - private final AppliedPTransform appliedPTransform; - private final TransformEvaluator rddEvaluator; - private final TransformT transform; - - - private RDDTransform(StreamingEvaluationContext context, - TransformEvaluator rddEvaluator, - TransformT transform) { - this.context = context; - this.appliedPTransform = context.getCurrentTransform(); - this.rddEvaluator = rddEvaluator; - this.transform = transform; - } + private static TransformEvaluator> gbko() { + return new TransformEvaluator>() { + @Override + public void evaluate(GroupByKeyOnly transform, EvaluationContext context) { + StreamingEvaluationContext sec = (StreamingEvaluationContext) context; - @Override - @SuppressWarnings("unchecked") - public JavaRDD> - call(JavaRDD> rdd) throws Exception { - AppliedPTransform existingAPT = context.getCurrentTransform(); - context.setCurrentTransform(appliedPTransform); - context.setInputRDD(transform, rdd); - rddEvaluator.evaluate(transform, context); - if (!context.hasOutputRDD(transform)) { - // fake RDD as output - context.setOutputRDD(transform, - context.getSparkContext().>emptyRDD()); + @SuppressWarnings("unchecked") + JavaDStream>> dStream = + (JavaDStream>>) sec.getStream(transform); + + @SuppressWarnings("unchecked") + final KvCoder coder = (KvCoder) sec.getInput(transform).getCoder(); + + JavaDStream>>> outStream = + dStream.transform(new Function>>, + JavaRDD>>>>() { + @Override + public JavaRDD>>> call( + JavaRDD>> rdd) throws Exception { + return GroupCombineFunctions.groupByKeyOnly(rdd, coder); + } + }); + sec.setStream(transform, outStream); } - JavaRDD> outRDD = - (JavaRDD>) context.getOutputRDD(transform); - context.setCurrentTransform(existingAPT); - return outRDD; - } + }; } - @SuppressWarnings("unchecked") - private static > TransformEvaluator foreachRDD( - final SparkPipelineTranslator rddTranslator) { - return new TransformEvaluator() { + private static + TransformEvaluator> gabw() { + return new TransformEvaluator>() { @Override - public void evaluate(TransformT transform, EvaluationContext context) { - TransformEvaluator rddEvaluator = - rddTranslator.translate((Class) transform.getClass()); + public void evaluate(final GroupAlsoByWindow transform, EvaluationContext context) { + final StreamingEvaluationContext sec = (StreamingEvaluationContext) context; + final SparkRuntimeContext runtimeContext = sec.getRuntimeContext(); + @SuppressWarnings("unchecked") + JavaDStream>>>> dStream = + (JavaDStream>>>>) + sec.getStream(transform); - StreamingEvaluationContext sec = (StreamingEvaluationContext) context; - if (sec.hasStream(transform)) { - JavaDStreamLike, ?, JavaRDD>> 
dStream = - (JavaDStreamLike, ?, JavaRDD>>) - sec.getStream(transform); + @SuppressWarnings("unchecked") + final KvCoder>> inputKvCoder = + (KvCoder>>) sec.getInput(transform).getCoder(); + + JavaDStream>>> outStream = + dStream.transform(new Function>>>>, JavaRDD>>>>() { + @Override + public JavaRDD>>> call(JavaRDD>>>> rdd) throws Exception { + final Accumulator accum = + AccumulatorSingleton.getInstance(new JavaSparkContext(rdd.context())); + return GroupCombineFunctions.groupAlsoByWindow(rdd, transform, runtimeContext, + accum, inputKvCoder); + } + }); + sec.setStream(transform, outStream); + } + }; + } - dStream.foreachRDD(new RDDOutputOperator<>(sec, rddEvaluator, transform)); - } else { - rddEvaluator.evaluate(transform, context); - } + private static TransformEvaluator> + grouped() { + return new TransformEvaluator>() { + @Override + public void evaluate(Combine.GroupedValues transform, + EvaluationContext context) { + StreamingEvaluationContext sec = (StreamingEvaluationContext) context; + @SuppressWarnings("unchecked") + JavaDStream>>> dStream = + (JavaDStream>>>) sec.getStream(transform); + sec.setStream(transform, dStream.map( + new TranslationUtils.CombineGroupedValues<>(transform))); } }; } - /** - * RDD output function. - * - * @param PTransform type - */ - private static final class RDDOutputOperator> - implements VoidFunction>> { + private static TransformEvaluator> + combineGlobally() { + return new TransformEvaluator>() { - private final StreamingEvaluationContext context; - private final AppliedPTransform appliedPTransform; - private final TransformEvaluator rddEvaluator; - private final TransformT transform; + @Override + public void evaluate(Combine.Globally transform, EvaluationContext context) { + StreamingEvaluationContext sec = (StreamingEvaluationContext) context; + @SuppressWarnings("unchecked") + final Combine.CombineFn globally = + (Combine.CombineFn) transform.getFn(); + @SuppressWarnings("unchecked") + JavaDStream> dStream = + (JavaDStream>) sec.getStream(transform); + + final Coder iCoder = sec.getInput(transform).getCoder(); + final Coder oCoder = sec.getOutput(transform).getCoder(); + final Coder aCoder; + try { + aCoder = globally.getAccumulatorCoder(sec.getPipeline().getCoderRegistry(), iCoder); + } catch (CannotProvideCoderException e) { + throw new IllegalStateException("Could not determine coder for accumulator", e); + } - private RDDOutputOperator(StreamingEvaluationContext context, - TransformEvaluator rddEvaluator, TransformT transform) { - this.context = context; - this.appliedPTransform = context.getCurrentTransform(); - this.rddEvaluator = rddEvaluator; - this.transform = transform; - } + JavaDStream> outStream = dStream.transform( + new Function>, JavaRDD>>() { + @Override + public JavaRDD> call(JavaRDD> rdd) + throws Exception { + JavaRDD outRdd = new JavaSparkContext(rdd.context()).parallelize( + // don't use Guava's ImmutableList.of as output may be null + CoderHelpers.toByteArrays(Collections.singleton( + GroupCombineFunctions.combineGlobally(rdd, globally, iCoder, aCoder)), oCoder)); + return outRdd.map(CoderHelpers.fromByteFunction(oCoder)).map( + WindowingHelpers.windowFunction()); + } + }); - @Override - @SuppressWarnings("unchecked") - public void call(JavaRDD> rdd) throws Exception { - AppliedPTransform existingAPT = context.getCurrentTransform(); - context.setCurrentTransform(appliedPTransform); - context.setInputRDD(transform, rdd); - rddEvaluator.evaluate(transform, context); - context.setCurrentTransform(existingAPT); - } + 
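// --------------------------------------------------------------------
// Editor's note (not part of the patch): combineGlobally() above (and
// combinePerKey() below) must serialize intermediate accumulators, so they
// ask the CombineFn for an accumulator coder against the pipeline's
// CoderRegistry and fail fast if none can be inferred. A hedged sketch of
// just that step; the type parameters are illustrative:
// --------------------------------------------------------------------
import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.transforms.Combine;

public class AccumulatorCoderSketch {
  static <InputT, AccumT, OutputT> Coder<AccumT> accumulatorCoder(
      Combine.CombineFn<InputT, AccumT, OutputT> fn,
      CoderRegistry registry,
      Coder<InputT> inputCoder) {
    try {
      // the CombineFn consults the registry for its accumulator type.
      return fn.getAccumulatorCoder(registry, inputCoder);
    } catch (CannotProvideCoderException e) {
      throw new IllegalStateException("Could not determine coder for accumulator", e);
    }
  }
}
// --------------------------------------------------------------------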
sec.setStream(transform, outStream); + } + }; } - private static TransformEvaluator> window() { - return new TransformEvaluator>() { + private static + TransformEvaluator> combinePerKey() { + return new TransformEvaluator>() { @Override - public void evaluate(Window.Bound transform, EvaluationContext context) { + public void evaluate(Combine.PerKey + transform, EvaluationContext context) { StreamingEvaluationContext sec = (StreamingEvaluationContext) context; - WindowFn windowFn = transform.getWindowFn(); @SuppressWarnings("unchecked") - JavaDStream> dStream = - (JavaDStream>) sec.getStream(transform); - if (windowFn instanceof FixedWindows) { - Duration windowDuration = Durations.milliseconds(((FixedWindows) windowFn).getSize() - .getMillis()); - sec.setStream(transform, dStream.window(windowDuration)); - } else if (windowFn instanceof SlidingWindows) { - Duration windowDuration = Durations.milliseconds(((SlidingWindows) windowFn).getSize() - .getMillis()); - Duration slideDuration = Durations.milliseconds(((SlidingWindows) windowFn).getPeriod() - .getMillis()); - sec.setStream(transform, dStream.window(windowDuration, slideDuration)); + final Combine.KeyedCombineFn keyed = + (Combine.KeyedCombineFn) transform.getFn(); + @SuppressWarnings("unchecked") + JavaDStream>> dStream = + (JavaDStream>>) sec.getStream(transform); + + @SuppressWarnings("unchecked") + KvCoder inputCoder = (KvCoder) sec.getInput(transform).getCoder(); + Coder keyCoder = inputCoder.getKeyCoder(); + Coder viCoder = inputCoder.getValueCoder(); + Coder vaCoder; + try { + vaCoder = keyed.getAccumulatorCoder( + context.getPipeline().getCoderRegistry(), keyCoder, viCoder); + } catch (CannotProvideCoderException e) { + throw new IllegalStateException("Could not determine coder for accumulator", e); } - //--- then we apply windowing to the elements - OldDoFn addWindowsDoFn = new AssignWindowsDoFn<>(windowFn); - DoFnFunction dofn = new DoFnFunction<>(addWindowsDoFn, - ((StreamingEvaluationContext) context).getRuntimeContext(), null); + Coder> kviCoder = KvCoder.of(keyCoder, viCoder); + Coder> kvaCoder = KvCoder.of(keyCoder, vaCoder); + //-- windowed coders + final WindowedValue.FullWindowedValueCoder wkCoder = + WindowedValue.FullWindowedValueCoder.of(keyCoder, + sec.getInput(transform).getWindowingStrategy().getWindowFn().windowCoder()); + final WindowedValue.FullWindowedValueCoder> wkviCoder = + WindowedValue.FullWindowedValueCoder.of(kviCoder, + sec.getInput(transform).getWindowingStrategy().getWindowFn().windowCoder()); + final WindowedValue.FullWindowedValueCoder> wkvaCoder = + WindowedValue.FullWindowedValueCoder.of(kvaCoder, + sec.getInput(transform).getWindowingStrategy().getWindowFn().windowCoder()); + + JavaDStream>> outStream = + dStream.transform(new Function>>, + JavaRDD>>>() { + @Override + public JavaRDD>> call( + JavaRDD>> rdd) throws Exception { + return GroupCombineFunctions.combinePerKey(rdd, keyed, wkCoder, wkviCoder, wkvaCoder); + } + }); + + sec.setStream(transform, outStream); + } + }; + } + + private static TransformEvaluator> parDo() { + return new TransformEvaluator>() { + @Override + public void evaluate(final ParDo.Bound transform, + final EvaluationContext context) { + final StreamingEvaluationContext sec = (StreamingEvaluationContext) context; + final SparkRuntimeContext runtimeContext = sec.getRuntimeContext(); + final Map, BroadcastHelper> sideInputs = + TranslationUtils.getSideInputs(transform.getSideInputs(), context); @SuppressWarnings("unchecked") - JavaDStreamLike, ?, JavaRDD>> dstream = - 
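// --------------------------------------------------------------------
// Editor's note (not part of the patch): combinePerKey() above shuffles
// windowed, coder-serialized bytes, so it composes KvCoders for key/value
// and key/accumulator pairs and wraps them in FullWindowedValueCoders using
// the input's window coder. A small sketch of that composition; the concrete
// coders (StringUtf8Coder, VarLongCoder, GlobalWindow.Coder) are
// illustrative placeholders:
// --------------------------------------------------------------------
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarLongCoder;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;

public class WindowedCoderSketch {
  public static void main(String[] args) {
    Coder<String> keyCoder = StringUtf8Coder.of();
    Coder<Long> valueCoder = VarLongCoder.of();
    // coder for the shuffled KV pairs.
    KvCoder<String, Long> kvCoder = KvCoder.of(keyCoder, valueCoder);
    // wrap with the window coder of the input's WindowingStrategy
    // (GlobalWindow.Coder.INSTANCE stands in for it here).
    WindowedValue.FullWindowedValueCoder<KV<String, Long>> windowedKvCoder =
        WindowedValue.FullWindowedValueCoder.of(kvCoder, GlobalWindow.Coder.INSTANCE);
    System.out.println(windowedKvCoder);
  }
}
// --------------------------------------------------------------------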
(JavaDStreamLike, ?, JavaRDD>>) - sec.getStream(transform); - sec.setStream(transform, dstream.mapPartitions(dofn)); + JavaDStream> dStream = + (JavaDStream>) sec.getStream(transform); + + JavaDStream> outStream = + dStream.transform(new Function>, + JavaRDD>>() { + @Override + public JavaRDD> call(JavaRDD> rdd) throws + Exception { + final Accumulator accum = + AccumulatorSingleton.getInstance(new JavaSparkContext(rdd.context())); + return rdd.mapPartitions( + new DoFnFunction<>(accum, transform.getFn(), runtimeContext, sideInputs)); + } + }); + + sec.setStream(transform, outStream); + } + }; + } + + private static TransformEvaluator> + multiDo() { + return new TransformEvaluator>() { + @Override + public void evaluate(final ParDo.BoundMulti transform, + final EvaluationContext context) { + final StreamingEvaluationContext sec = (StreamingEvaluationContext) context; + final SparkRuntimeContext runtimeContext = sec.getRuntimeContext(); + final Map, BroadcastHelper> sideInputs = + TranslationUtils.getSideInputs(transform.getSideInputs(), context); + @SuppressWarnings("unchecked") + JavaDStream> dStream = + (JavaDStream>) sec.getStream(transform); + JavaPairDStream, WindowedValue> all = dStream.transformToPair( + new Function>, + JavaPairRDD, WindowedValue>>() { + @Override + public JavaPairRDD, WindowedValue> call( + JavaRDD> rdd) throws Exception { + final Accumulator accum = + AccumulatorSingleton.getInstance(new JavaSparkContext(rdd.context())); + return rdd.mapPartitionsToPair(new MultiDoFnFunction<>(accum, transform.getFn(), + runtimeContext, transform.getMainOutputTag(), sideInputs)); + } + }).cache(); + PCollectionTuple pct = sec.getOutput(transform); + for (Map.Entry, PCollection> e : pct.getAll().entrySet()) { + @SuppressWarnings("unchecked") + JavaPairDStream, WindowedValue> filtered = + all.filter(new TranslationUtils.TupleTagFilter(e.getKey())); + @SuppressWarnings("unchecked") + // Object is the best we can do since different outputs can have different tags + JavaDStream> values = + (JavaDStream>) + (JavaDStream) TranslationUtils.dStreamValues(filtered); + sec.setStream(e.getValue(), values); + } } }; } @@ -321,79 +475,54 @@ public final class StreamingTransformTranslator { .newHashMap(); static { + EVALUATORS.put(GroupByKeyViaGroupByKeyOnly.GroupByKeyOnly.class, gbko()); + EVALUATORS.put(GroupByKeyViaGroupByKeyOnly.GroupAlsoByWindow.class, gabw()); + EVALUATORS.put(Combine.GroupedValues.class, grouped()); + EVALUATORS.put(Combine.Globally.class, combineGlobally()); + EVALUATORS.put(Combine.PerKey.class, combinePerKey()); + EVALUATORS.put(ParDo.Bound.class, parDo()); + EVALUATORS.put(ParDo.BoundMulti.class, multiDo()); EVALUATORS.put(ConsoleIO.Write.Unbound.class, print()); EVALUATORS.put(CreateStream.QueuedValues.class, createFromQueue()); - EVALUATORS.put(Create.Values.class, create()); EVALUATORS.put(KafkaIO.Read.Unbound.class, kafka()); EVALUATORS.put(Window.Bound.class, window()); EVALUATORS.put(Flatten.FlattenPCollectionList.class, flattenPColl()); } - private static final Set> UNSUPPORTED_EVALUATORS = Sets - .newHashSet(); - - static { - //TODO - add support for the following - UNSUPPORTED_EVALUATORS.add(TextIO.Read.Bound.class); - UNSUPPORTED_EVALUATORS.add(TextIO.Write.Bound.class); - UNSUPPORTED_EVALUATORS.add(AvroIO.Read.Bound.class); - UNSUPPORTED_EVALUATORS.add(AvroIO.Write.Bound.class); - UNSUPPORTED_EVALUATORS.add(HadoopIO.Read.Bound.class); - UNSUPPORTED_EVALUATORS.add(HadoopIO.Write.Bound.class); - } - - @SuppressWarnings("unchecked") - private static > 
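// --------------------------------------------------------------------
// Editor's note (not part of the patch): multiDo() above runs the
// MultiDoFnFunction once per batch into a cached (tag, value) pair stream
// and then derives one output stream per TupleTag with filter(). A
// stripped-down sketch of that fan-out, with plain Strings standing in for
// Beam TupleTags and Integers for the windowed values:
// --------------------------------------------------------------------
import org.apache.spark.api.java.function.Function;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaPairDStream;
import scala.Tuple2;

public class TaggedOutputsSketch {
  static JavaDStream<Integer> outputFor(JavaPairDStream<String, Integer> all,
                                        final String tag) {
    // all outputs share the cached upstream pair stream, so the DoFn runs
    // only once per micro-batch; each output just filters on its own tag.
    return all
        .filter(new Function<Tuple2<String, Integer>, Boolean>() {
          @Override
          public Boolean call(Tuple2<String, Integer> kv) {
            return tag.equals(kv._1());
          }
        })
        .map(new Function<Tuple2<String, Integer>, Integer>() {
          @Override
          public Integer call(Tuple2<String, Integer> kv) {
            return kv._2();
          }
        });
  }
}
// --------------------------------------------------------------------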
TransformEvaluator - getTransformEvaluator(Class clazz, SparkPipelineTranslator rddTranslator) { - TransformEvaluator transform = - (TransformEvaluator) EVALUATORS.get(clazz); - if (transform == null) { - if (UNSUPPORTED_EVALUATORS.contains(clazz)) { - throw new UnsupportedOperationException("Beam transformation " + clazz - .getCanonicalName() - + " is currently unsupported by the Spark streaming pipeline"); - } - // DStream transformations will transform an RDD into another RDD - // Actions will create output - // In Beam it depends on the PTransform's Input and Output class - Class pTOutputClazz = getPTransformOutputClazz(clazz); - if (PDone.class.equals(pTOutputClazz)) { - return foreachRDD(rddTranslator); - } else { - return rddTransform(rddTranslator); - } - } - return transform; - } - - private static > Class - getPTransformOutputClazz(Class clazz) { - Type[] types = ((ParameterizedType) clazz.getGenericSuperclass()).getActualTypeArguments(); - return TypeToken.of(clazz).resolveType(types[1]).getRawType(); - } - /** - * Translator matches Beam transformation with the appropriate Spark streaming evaluator. - * rddTranslator uses Spark evaluators in transform/foreachRDD to evaluate the transformation + * Translator matches Beam transformation with the appropriate evaluator. */ public static class Translator implements SparkPipelineTranslator { - private final SparkPipelineTranslator rddTranslator; + private final SparkPipelineTranslator batchTranslator; - public Translator(SparkPipelineTranslator rddTranslator) { - this.rddTranslator = rddTranslator; + Translator(SparkPipelineTranslator batchTranslator) { + this.batchTranslator = batchTranslator; } @Override public boolean hasTranslation(Class> clazz) { - // streaming includes rdd transformations as well - return EVALUATORS.containsKey(clazz) || rddTranslator.hasTranslation(clazz); + // streaming includes rdd/bounded transformations as well + return EVALUATORS.containsKey(clazz) || batchTranslator.hasTranslation(clazz); + } + + @Override + public > TransformEvaluator + translateBounded(Class clazz) { + TransformEvaluator transformEvaluator = batchTranslator.translateBounded(clazz); + checkState(transformEvaluator != null, + "No TransformEvaluator registered for BOUNDED transform %s", clazz); + return transformEvaluator; } @Override public > TransformEvaluator - translate(Class clazz) { - return getTransformEvaluator(clazz, rddTranslator); + translateUnbounded(Class clazz) { + @SuppressWarnings("unchecked") TransformEvaluator transformEvaluator = + (TransformEvaluator) EVALUATORS.get(clazz); + checkState(transformEvaluator != null, + "No TransformEvaluator registered for for UNBOUNDED transform %s", clazz); + return transformEvaluator; } } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0feb6499/runners/spark/src/main/java/org/apache/beam/runners/spark/util/BroadcastHelper.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/util/BroadcastHelper.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/util/BroadcastHelper.java index 5c13b80..0e742eb 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/util/BroadcastHelper.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/util/BroadcastHelper.java @@ -52,8 +52,12 @@ public abstract class BroadcastHelper implements Serializable { public abstract T getValue(); + public abstract boolean isBroadcasted(); + public abstract void 
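// --------------------------------------------------------------------
// Editor's note (not part of the patch): the new Translator above replaces
// the reflective getTransformEvaluator() fallback with explicit
// translateBounded()/translateUnbounded() lookups -- unbounded transforms
// come from the streaming EVALUATORS map, bounded ones are delegated to the
// wrapped batch translator, and a missing registration fails fast. A
// condensed stand-in sketch (simplified interfaces, String keys instead of
// PTransform classes):
// --------------------------------------------------------------------
import static com.google.common.base.Preconditions.checkState;

import java.util.HashMap;
import java.util.Map;

public class TranslatorSketch {
  interface Evaluator { void evaluate(); }

  interface Translator {
    boolean hasTranslation(String transform);
    Evaluator translateBounded(String transform);
    Evaluator translateUnbounded(String transform);
  }

  static class StreamingTranslator implements Translator {
    private final Map<String, Evaluator> unbounded = new HashMap<>();
    private final Translator batch;

    StreamingTranslator(Translator batch) { this.batch = batch; }

    @Override public boolean hasTranslation(String transform) {
      // streaming also covers bounded transforms by delegating to batch.
      return unbounded.containsKey(transform) || batch.hasTranslation(transform);
    }

    @Override public Evaluator translateBounded(String transform) {
      Evaluator e = batch.translateBounded(transform);
      checkState(e != null, "No evaluator registered for BOUNDED transform %s", transform);
      return e;
    }

    @Override public Evaluator translateUnbounded(String transform) {
      Evaluator e = unbounded.get(transform);
      checkState(e != null, "No evaluator registered for UNBOUNDED transform %s", transform);
      return e;
    }
  }
}
// --------------------------------------------------------------------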
broadcast(JavaSparkContext jsc);
 
+  public abstract void unregister();
+
   /**
    * A {@link BroadcastHelper} that relies on the underlying
    * Spark serialization (Kryo) to broadcast values. This is appropriate when
@@ -77,9 +81,20 @@ public abstract class BroadcastHelper<T> implements Serializable {
     }
 
     @Override
+    public boolean isBroadcasted() {
+      return bcast != null;
+    }
+
+    @Override
     public void broadcast(JavaSparkContext jsc) {
       this.bcast = jsc.broadcast(value);
     }
+
+    @Override
+    public void unregister() {
+      this.bcast.destroy();
+      this.bcast = null;
+    }
   }
 
   /**
@@ -107,10 +122,21 @@ public abstract class BroadcastHelper<T> implements Serializable {
     }
 
     @Override
+    public boolean isBroadcasted() {
+      return bcast != null;
+    }
+
+    @Override
     public void broadcast(JavaSparkContext jsc) {
       this.bcast = jsc.broadcast(CoderHelpers.toByteArray(value, coder));
     }
 
+    @Override
+    public void unregister() {
+      this.bcast.destroy();
+      this.bcast = null;
+    }
+
     private T deserialize() {
       T val;
       try {
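// --------------------------------------------------------------------
// Editor's note (not part of the patch): one plausible use of the new
// isBroadcasted()/unregister() hooks added above -- this is an editor's
// guess, not code from the commit. After a streaming driver is recovered
// from a checkpoint, a stale broadcast can be torn down and the side-input
// value broadcast again on the fresh context:
// --------------------------------------------------------------------
import org.apache.beam.runners.spark.util.BroadcastHelper;
import org.apache.spark.api.java.JavaSparkContext;

public class RebroadcastSketch {
  static <T> void ensureBroadcast(BroadcastHelper<T> helper, JavaSparkContext jsc) {
    if (helper.isBroadcasted()) {
      // drop the broadcast that belonged to the pre-failure context.
      helper.unregister();
    }
    helper.broadcast(jsc);
  }
}
// --------------------------------------------------------------------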