Return-Path: X-Original-To: archive-asf-public-internal@cust-asf2.ponee.io Delivered-To: archive-asf-public-internal@cust-asf2.ponee.io Received: from cust-asf.ponee.io (cust-asf.ponee.io [163.172.22.183]) by cust-asf2.ponee.io (Postfix) with ESMTP id F33AD200BAA for ; Thu, 13 Oct 2016 02:38:43 +0200 (CEST) Received: by cust-asf.ponee.io (Postfix) id F1E11160AEE; Thu, 13 Oct 2016 00:38:43 +0000 (UTC) Delivered-To: archive-asf-public@cust-asf.ponee.io Received: from mail.apache.org (hermes.apache.org [140.211.11.3]) by cust-asf.ponee.io (Postfix) with SMTP id EAB94160ACA for ; Thu, 13 Oct 2016 02:38:41 +0200 (CEST) Received: (qmail 513 invoked by uid 500); 13 Oct 2016 00:38:41 -0000 Mailing-List: contact commits-help@beam.incubator.apache.org; run by ezmlm Precedence: bulk List-Help: List-Unsubscribe: List-Post: List-Id: Reply-To: dev@beam.incubator.apache.org Delivered-To: mailing list commits@beam.incubator.apache.org Received: (qmail 504 invoked by uid 99); 13 Oct 2016 00:38:41 -0000 Received: from pnap-us-west-generic-nat.apache.org (HELO spamd4-us-west.apache.org) (209.188.14.142) by apache.org (qpsmtpd/0.29) with ESMTP; Thu, 13 Oct 2016 00:38:40 +0000 Received: from localhost (localhost [127.0.0.1]) by spamd4-us-west.apache.org (ASF Mail Server at spamd4-us-west.apache.org) with ESMTP id 71D98C0C5C for ; Thu, 13 Oct 2016 00:38:40 +0000 (UTC) X-Virus-Scanned: Debian amavisd-new at spamd4-us-west.apache.org X-Spam-Flag: NO X-Spam-Score: -6.209 X-Spam-Level: X-Spam-Status: No, score=-6.209 tagged_above=-999 required=6.31 tests=[KAM_ASCII_DIVIDERS=0.8, KAM_LAZY_DOMAIN_SECURITY=1, RCVD_IN_DNSWL_HI=-5, RCVD_IN_MSPIKE_H3=-0.01, RCVD_IN_MSPIKE_WL=-0.01, RP_MATCHES_RCVD=-2.999, T_FILL_THIS_FORM_SHORT=0.01] autolearn=disabled Received: from mx1-lw-us.apache.org ([10.40.0.8]) by localhost (spamd4-us-west.apache.org [10.40.0.11]) (amavisd-new, port 10024) with ESMTP id mshHZnSuPIs4 for ; Thu, 13 Oct 2016 00:38:30 +0000 (UTC) Received: from mail.apache.org 
(hermes.apache.org [140.211.11.3]) by mx1-lw-us.apache.org (ASF Mail Server at mx1-lw-us.apache.org) with SMTP id 892BD5F201 for ; Thu, 13 Oct 2016 00:38:29 +0000 (UTC) Received: (qmail 203 invoked by uid 99); 13 Oct 2016 00:38:28 -0000 Received: from git1-us-west.apache.org (HELO git1-us-west.apache.org) (140.211.11.23) by apache.org (qpsmtpd/0.29) with ESMTP; Thu, 13 Oct 2016 00:38:28 +0000 Received: by git1-us-west.apache.org (ASF Mail Server at git1-us-west.apache.org, from userid 33) id A4CF6E00A4; Thu, 13 Oct 2016 00:38:28 +0000 (UTC) Content-Type: text/plain; charset="us-ascii" MIME-Version: 1.0 Content-Transfer-Encoding: 7bit From: bchambers@apache.org To: commits@beam.incubator.apache.org Date: Thu, 13 Oct 2016 00:38:29 -0000 Message-Id: <0a44c0a43b1a4f628eaa4f362de71e19@git.apache.org> In-Reply-To: <36b11612fed04cfbb66b56342e4f69d1@git.apache.org> References: <36b11612fed04cfbb66b56342e4f69d1@git.apache.org> X-Mailer: ASF-Git Admin Mailer Subject: [2/4] incubator-beam git commit: [BEAM-65] SplittableDoFn prototype. 
archived-at: Thu, 13 Oct 2016 00:38:44 -0000 http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a0a24883/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnAdapters.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnAdapters.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnAdapters.java index 3eee74a..f671a67 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnAdapters.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnAdapters.java @@ -25,6 +25,7 @@ import org.apache.beam.sdk.transforms.reflect.DoFnInvoker; import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; import org.apache.beam.sdk.transforms.reflect.DoFnSignature; import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; +import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.util.WindowingInternals; @@ -37,8 +38,8 @@ import org.joda.time.Instant; /** * Utility class containing adapters to/from {@link DoFn} and {@link OldDoFn}. * - * @deprecated This class will go away when we start running {@link DoFn}'s directly (using - * {@link DoFnInvoker}) rather than via {@link OldDoFn}. + * @deprecated This class will go away when we start running {@link DoFn}'s directly (using {@link + * DoFnInvoker}) rather than via {@link OldDoFn}. */ @Deprecated public class DoFnAdapters { @@ -176,6 +177,18 @@ public class DoFnAdapters { } /** + * If the fn was created using {@link #toOldDoFn}, returns the original {@link DoFn}. Otherwise, + * returns {@code null}. 
+ */ + public static DoFn getDoFn(OldDoFn fn) { + if (fn instanceof SimpleDoFnAdapter) { + return ((SimpleDoFnAdapter) fn).fn; + } else { + return null; + } + } + + /** * Wraps a {@link DoFn} that doesn't require access to {@link BoundedWindow} as an {@link * OldDoFn}. */ @@ -324,6 +337,11 @@ public class DoFnAdapters { public DoFn.OutputReceiver outputReceiver() { throw new UnsupportedOperationException("outputReceiver() exists only for testing"); } + + @Override + public RestrictionTracker restrictionTracker() { + throw new UnsupportedOperationException("This is a non-splittable DoFn"); + } } /** @@ -412,5 +430,10 @@ public class DoFnAdapters { public DoFn.OutputReceiver outputReceiver() { throw new UnsupportedOperationException("outputReceiver() exists only for testing"); } + + @Override + public RestrictionTracker restrictionTracker() { + throw new UnsupportedOperationException("This is a non-splittable DoFn"); + } } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a0a24883/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java index 11a4cbd..302bb02 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java @@ -46,7 +46,9 @@ import org.apache.beam.sdk.util.UserCodeException; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.util.WindowingInternals; import org.apache.beam.sdk.util.state.InMemoryStateInternals; +import org.apache.beam.sdk.util.state.InMemoryTimerInternals; import org.apache.beam.sdk.util.state.StateInternals; +import org.apache.beam.sdk.util.state.TimerCallback; import org.apache.beam.sdk.values.PCollectionView; import 
org.apache.beam.sdk.values.TimestampedValue; import org.apache.beam.sdk.values.TupleTag; @@ -222,8 +224,11 @@ public class DoFnTester implements AutoCloseable { if (state == State.UNINITIALIZED) { initializeState(); } - TestContext context = createContext(fn); + TestContext context = createContext(fn); context.setupDelegateAggregators(); + // State and timer internals are per-bundle. + stateInternals = InMemoryStateInternals.forKey(new Object()); + timerInternals = new InMemoryTimerInternals(); try { fn.startBundle(context); } catch (UserCodeException e) { @@ -460,6 +465,35 @@ public class DoFnTester implements AutoCloseable { return extractAggregatorValue(agg.getName(), agg.getCombineFn()); } + private static TimerCallback collectInto(final List firedTimers) { + return new TimerCallback() { + @Override + public void onTimer(TimerInternals.TimerData timer) throws Exception { + firedTimers.add(timer); + } + }; + } + + public List advanceInputWatermark(Instant newWatermark) { + try { + final List firedTimers = new ArrayList<>(); + timerInternals.advanceInputWatermark(collectInto(firedTimers), newWatermark); + return firedTimers; + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + public List advanceProcessingTime(Instant newProcessingTime) { + try { + final List firedTimers = new ArrayList<>(); + timerInternals.advanceProcessingTime(collectInto(firedTimers), newProcessingTime); + return firedTimers; + } catch (Exception e) { + throw new RuntimeException(e); + } + } + private AggregateT extractAggregatorValue( String name, CombineFn combiner) { @SuppressWarnings("unchecked") @@ -476,41 +510,27 @@ public class DoFnTester implements AutoCloseable { return MoreObjects.firstNonNull(elems, Collections.>emptyList()); } - private TestContext createContext(OldDoFn fn) { - return new TestContext<>(fn, options, mainOutputTag, outputs, accumulators); + private TestContext createContext(OldDoFn fn) { + return new TestContext(); } - private static class 
TestContext extends OldDoFn.Context { - private final PipelineOptions opts; - private final TupleTag mainOutputTag; - private final Map, List>> outputs; - private final Map accumulators; - - public TestContext( - OldDoFn fn, - PipelineOptions opts, - TupleTag mainOutputTag, - Map, List>> outputs, - Map accumulators) { + private class TestContext extends OldDoFn.Context { + TestContext() { fn.super(); - this.opts = opts; - this.mainOutputTag = mainOutputTag; - this.outputs = outputs; - this.accumulators = accumulators; } @Override public PipelineOptions getPipelineOptions() { - return opts; + return options; } @Override - public void output(OutT output) { + public void output(OutputT output) { sideOutput(mainOutputTag, output); } @Override - public void outputWithTimestamp(OutT output, Instant timestamp) { + public void outputWithTimestamp(OutputT output, Instant timestamp) { sideOutputWithTimestamp(mainOutputTag, output, timestamp); } @@ -570,40 +590,27 @@ public class DoFnTester implements AutoCloseable { } } - private TestProcessContext createProcessContext( + private TestProcessContext createProcessContext( OldDoFn fn, TimestampedValue elem) { WindowedValue windowedValue = WindowedValue.timestampedValueInGlobalWindow( elem.getValue(), elem.getTimestamp()); - return new TestProcessContext<>(fn, - createContext(fn), - windowedValue, - mainOutputTag, - sideInputs); - } - - private static class TestProcessContext extends OldDoFn.ProcessContext { - private final TestContext context; - private final TupleTag mainOutputTag; - private final WindowedValue element; - private final Map, Map> sideInputs; - - private TestProcessContext( - OldDoFn fn, - TestContext context, - WindowedValue element, - TupleTag mainOutputTag, - Map, Map> sideInputs) { + return new TestProcessContext(windowedValue); + } + + private class TestProcessContext extends OldDoFn.ProcessContext { + private final TestContext context; + private final WindowedValue element; + + private 
TestProcessContext(WindowedValue element) { fn.super(); - this.context = context; + this.context = createContext(fn); this.element = element; - this.mainOutputTag = mainOutputTag; - this.sideInputs = sideInputs; } @Override - public InT element() { + public InputT element() { return element.getValue(); } @@ -638,10 +645,8 @@ public class DoFnTester implements AutoCloseable { } @Override - public WindowingInternals windowingInternals() { - return new WindowingInternals() { - StateInternals stateInternals = InMemoryStateInternals.forKey(new Object()); - + public WindowingInternals windowingInternals() { + return new WindowingInternals() { @Override public StateInternals stateInternals() { return stateInternals; @@ -649,7 +654,7 @@ public class DoFnTester implements AutoCloseable { @Override public void outputWindowedValue( - OutT output, + OutputT output, Instant timestamp, Collection windows, PaneInfo pane) { @@ -658,8 +663,7 @@ public class DoFnTester implements AutoCloseable { @Override public TimerInternals timerInternals() { - throw - new UnsupportedOperationException("Timer Internals are not supported in DoFnTester"); + return timerInternals; } @Override @@ -695,12 +699,12 @@ public class DoFnTester implements AutoCloseable { } @Override - public void output(OutT output) { + public void output(OutputT output) { sideOutput(mainOutputTag, output); } @Override - public void outputWithTimestamp(OutT output, Instant timestamp) { + public void outputWithTimestamp(OutputT output, Instant timestamp) { sideOutputWithTimestamp(mainOutputTag, output, timestamp); } @@ -774,6 +778,9 @@ public class DoFnTester implements AutoCloseable { /** The outputs from the {@link DoFn} under test. */ private Map, List>> outputs; + private InMemoryStateInternals stateInternals; + private InMemoryTimerInternals timerInternals; + /** The state of processing of the {@link DoFn} under test. 
*/ private State state = State.UNINITIALIZED; http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a0a24883/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java index 2443d8e..fdef908 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java @@ -17,6 +17,8 @@ */ package org.apache.beam.sdk.transforms; +import static com.google.common.base.Preconditions.checkArgument; + import com.google.common.collect.ImmutableList; import java.io.Serializable; import java.util.Arrays; @@ -27,6 +29,7 @@ import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.runners.PipelineRunner; import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.transforms.display.DisplayData.Builder; +import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; import org.apache.beam.sdk.transforms.windowing.WindowFn; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.util.StringUtils; @@ -716,6 +719,8 @@ public class ParDo { @Override public PCollection apply(PCollection input) { + checkArgument( + !isSplittable(fn), "Splittable DoFn not supported by the current runner"); return PCollection.createPrimitiveOutputInternal( input.getPipeline(), input.getWindowingStrategy(), @@ -925,6 +930,9 @@ public class ParDo { @Override public PCollectionTuple apply(PCollection input) { + checkArgument( + !isSplittable(fn), "Splittable DoFn not supported by the current runner"); + PCollectionTuple outputs = PCollectionTuple.ofPrimitiveOutputsInternal( input.getPipeline(), TupleTagList.of(mainOutputTag).and(sideOutputTags.getAll()), @@ -997,4 +1005,15 @@ public class ParDo { .add(DisplayData.item("fn", 
fnClass) .withLabel("Transform Function")); } + + private static boolean isSplittable(OldDoFn oldDoFn) { + DoFn fn = DoFnAdapters.getDoFn(oldDoFn); + if (fn == null) { + return false; + } + return DoFnSignatures.INSTANCE + .getOrParseSignature(fn.getClass()) + .processElement() + .isSplittable(); + } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a0a24883/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java index eb6961c..9672d53 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java @@ -17,7 +17,10 @@ */ package org.apache.beam.sdk.transforms.reflect; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderRegistry; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; /** * Interface for invoking the {@code DoFn} processing methods. @@ -43,7 +46,28 @@ public interface DoFnInvoker { * * @param c The {@link DoFn.ProcessContext} to invoke the fn with. * @param extra Factory for producing extra parameter objects (such as window), if necessary. + * @return The {@link DoFn.ProcessContinuation} returned by the underlying method, or {@link + * DoFn.ProcessContinuation#stop()} if it returns {@code void}. */ - void invokeProcessElement( + DoFn.ProcessContinuation invokeProcessElement( DoFn.ProcessContext c, DoFn.ExtraContextFactory extra); + + /** Invoke the {@link DoFn.GetInitialRestriction} method on the bound {@link DoFn}. 
*/ + RestrictionT invokeGetInitialRestriction(InputT element); + + /** + * Invoke the {@link DoFn.GetRestrictionCoder} method on the bound {@link DoFn}. Called only + * during pipeline construction time. + */ + Coder invokeGetRestrictionCoder(CoderRegistry coderRegistry); + + /** Invoke the {@link DoFn.SplitRestriction} method on the bound {@link DoFn}. */ + void invokeSplitRestriction( + InputT element, + RestrictionT restriction, + DoFn.OutputReceiver restrictionReceiver); + + /** Invoke the {@link DoFn.NewTracker} method on the bound {@link DoFn}. */ + > TrackerT invokeNewTracker( + RestrictionT restriction); } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a0a24883/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokers.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokers.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokers.java index da88587..fd057c3 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokers.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokers.java @@ -19,6 +19,7 @@ package org.apache.beam.sdk.transforms.reflect; import static com.google.common.base.Preconditions.checkArgument; +import com.google.common.reflect.TypeToken; import java.lang.reflect.Constructor; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; @@ -26,6 +27,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.EnumMap; import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; import javax.annotation.Nullable; import net.bytebuddy.ByteBuddy; @@ -35,10 +37,12 @@ import net.bytebuddy.description.method.MethodDescription; import net.bytebuddy.description.modifier.FieldManifestation; import net.bytebuddy.description.modifier.Visibility; import 
net.bytebuddy.description.type.TypeDescription; +import net.bytebuddy.description.type.TypeList; import net.bytebuddy.dynamic.DynamicType; import net.bytebuddy.dynamic.loading.ClassLoadingStrategy; import net.bytebuddy.dynamic.scaffold.InstrumentedType; import net.bytebuddy.dynamic.scaffold.subclass.ConstructorStrategy; +import net.bytebuddy.implementation.ExceptionMethod; import net.bytebuddy.implementation.FixedValue; import net.bytebuddy.implementation.Implementation; import net.bytebuddy.implementation.Implementation.Context; @@ -48,6 +52,7 @@ import net.bytebuddy.implementation.bytecode.ByteCodeAppender; import net.bytebuddy.implementation.bytecode.StackManipulation; import net.bytebuddy.implementation.bytecode.Throw; import net.bytebuddy.implementation.bytecode.assign.Assigner; +import net.bytebuddy.implementation.bytecode.assign.TypeCasting; import net.bytebuddy.implementation.bytecode.member.FieldAccess; import net.bytebuddy.implementation.bytecode.member.MethodInvocation; import net.bytebuddy.implementation.bytecode.member.MethodReturn; @@ -57,12 +62,17 @@ import net.bytebuddy.jar.asm.MethodVisitor; import net.bytebuddy.jar.asm.Opcodes; import net.bytebuddy.jar.asm.Type; import net.bytebuddy.matcher.ElementMatchers; +import org.apache.beam.sdk.coders.CannotProvideCoderException; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderRegistry; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.DoFn.ExtraContextFactory; import org.apache.beam.sdk.transforms.DoFn.ProcessElement; import org.apache.beam.sdk.transforms.DoFnAdapters; import org.apache.beam.sdk.transforms.OldDoFn; +import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; import org.apache.beam.sdk.util.UserCodeException; +import org.apache.beam.sdk.values.TypeDescriptor; /** Dynamically generates {@link DoFnInvoker} instances for invoking a {@link DoFn}. 
*/ public class DoFnInvokers { @@ -81,10 +91,10 @@ public class DoFnInvokers { private DoFnInvokers() {} /** - * Creates a {@link DoFnInvoker} for the given {@link Object}, which should be either a - * {@link DoFn} or an {@link OldDoFn}. The expected use would be to deserialize a user's - * function as an {@link Object} and then pass it to this method, so there is no need to - * statically specify what sort of object it is. + * Creates a {@link DoFnInvoker} for the given {@link Object}, which should be either a {@link + * DoFn} or an {@link OldDoFn}. The expected use would be to deserialize a user's function as an + * {@link Object} and then pass it to this method, so there is no need to statically specify what + * sort of object it is. * * @deprecated this is to be used only as a migration path for decoupling upgrades */ @@ -92,15 +102,16 @@ public class DoFnInvokers { public DoFnInvoker invokerFor(Object deserializedFn) { if (deserializedFn instanceof DoFn) { return newByteBuddyInvoker((DoFn) deserializedFn); - } else if (deserializedFn instanceof OldDoFn){ + } else if (deserializedFn instanceof OldDoFn) { return new OldDoFnInvoker<>((OldDoFn) deserializedFn); } else { - throw new IllegalArgumentException(String.format( - "Cannot create a %s for %s; it should be either a %s or an %s.", - DoFnInvoker.class.getSimpleName(), - deserializedFn.toString(), - DoFn.class.getSimpleName(), - OldDoFn.class.getSimpleName())); + throw new IllegalArgumentException( + String.format( + "Cannot create a %s for %s; it should be either a %s or an %s.", + DoFnInvoker.class.getSimpleName(), + deserializedFn.toString(), + DoFn.class.getSimpleName(), + OldDoFn.class.getSimpleName())); } } @@ -113,12 +124,13 @@ public class DoFnInvokers { } @Override - public void invokeProcessElement( + public DoFn.ProcessContinuation invokeProcessElement( DoFn.ProcessContext c, ExtraContextFactory extra) { OldDoFn.ProcessContext oldCtx = DoFnAdapters.adaptProcessContext(fn, c, extra); try { 
fn.processElement(oldCtx); + return DoFn.ProcessContinuation.stop(); } catch (Throwable exc) { throw UserCodeException.wrap(exc); } @@ -161,14 +173,37 @@ public class DoFnInvokers { throw UserCodeException.wrap(exc); } } + + @Override + public RestrictionT invokeGetInitialRestriction(InputT element) { + throw new UnsupportedOperationException("OldDoFn is not splittable"); + } + + @Override + public Coder invokeGetRestrictionCoder( + CoderRegistry coderRegistry) { + throw new UnsupportedOperationException("OldDoFn is not splittable"); + } + + @Override + public void invokeSplitRestriction( + InputT element, RestrictionT restriction, DoFn.OutputReceiver receiver) { + throw new UnsupportedOperationException("OldDoFn is not splittable"); + } + + @Override + public > + TrackerT invokeNewTracker(RestrictionT restriction) { + throw new UnsupportedOperationException("OldDoFn is not splittable"); + } } /** @return the {@link DoFnInvoker} for the given {@link DoFn}. */ @SuppressWarnings({"unchecked", "rawtypes"}) public DoFnInvoker newByteBuddyInvoker( DoFn fn) { - return newByteBuddyInvoker(DoFnSignatures.INSTANCE.getOrParseSignature( - (Class) fn.getClass()), fn); + return newByteBuddyInvoker( + DoFnSignatures.INSTANCE.getOrParseSignature((Class) fn.getClass()), fn); } /** @return the {@link DoFnInvoker} for the given {@link DoFn}. */ @@ -214,6 +249,32 @@ public class DoFnInvokers { return constructor; } + /** Default implementation of {@link DoFn.SplitRestriction}, for delegation by bytebuddy. */ + public static class DefaultSplitRestriction { + /** Doesn't split the restriction. */ + @SuppressWarnings("unused") + public static void invokeSplitRestriction( + InputT element, RestrictionT restriction, DoFn.OutputReceiver receiver) { + receiver.output(restriction); + } + } + + /** Default implementation of {@link DoFn.GetRestrictionCoder}, for delegation by bytebuddy. 
*/ + public static class DefaultRestrictionCoder { + private final TypeToken restrictionType; + + DefaultRestrictionCoder(TypeToken restrictionType) { + this.restrictionType = restrictionType; + } + + /** Doesn't split the restriction. */ + @SuppressWarnings({"unused", "unchecked"}) + public Coder invokeGetRestrictionCoder(CoderRegistry registry) + throws CannotProvideCoderException { + return (Coder) registry.getCoder(TypeDescriptor.of(restrictionType.getType())); + } + } + /** Generates a {@link DoFnInvoker} class for the given {@link DoFnSignature}. */ private static Class> generateInvokerClass(DoFnSignature signature) { Class> fnClass = signature.fnClass(); @@ -247,7 +308,15 @@ public class DoFnInvokers { .method(ElementMatchers.named("invokeSetup")) .intercept(delegateOrNoop(signature.setup())) .method(ElementMatchers.named("invokeTeardown")) - .intercept(delegateOrNoop(signature.teardown())); + .intercept(delegateOrNoop(signature.teardown())) + .method(ElementMatchers.named("invokeGetInitialRestriction")) + .intercept(delegateWithDowncastOrThrow(signature.getInitialRestriction())) + .method(ElementMatchers.named("invokeSplitRestriction")) + .intercept(splitRestrictionDelegation(signature)) + .method(ElementMatchers.named("invokeGetRestrictionCoder")) + .intercept(getRestrictionCoderDelegation(signature)) + .method(ElementMatchers.named("invokeNewTracker")) + .intercept(delegateWithDowncastOrThrow(signature.newTracker())); DynamicType.Unloaded unloaded = builder.make(); @@ -260,6 +329,28 @@ public class DoFnInvokers { return res; } + private static Implementation getRestrictionCoderDelegation(DoFnSignature signature) { + if (signature.processElement().isSplittable()) { + if (signature.getRestrictionCoder() == null) { + return MethodDelegation.to( + new DefaultRestrictionCoder(signature.getInitialRestriction().restrictionT())); + } else { + return new DowncastingParametersMethodDelegation( + signature.getRestrictionCoder().targetMethod()); + } + } else { + 
return ExceptionMethod.throwing(UnsupportedOperationException.class); + } + } + + private static Implementation splitRestrictionDelegation(DoFnSignature signature) { + if (signature.splitRestriction() == null) { + return MethodDelegation.to(DefaultSplitRestriction.class); + } else { + return new DowncastingParametersMethodDelegation(signature.splitRestriction().targetMethod()); + } + } + /** Delegates to the given method if available, or does nothing. */ private static Implementation delegateOrNoop(DoFnSignature.DoFnMethod method) { return (method == null) @@ -267,6 +358,13 @@ public class DoFnInvokers { : new DoFnMethodDelegation(method.targetMethod()); } + /** Delegates to the given method if available, or throws UnsupportedOperationException. */ + private static Implementation delegateWithDowncastOrThrow(DoFnSignature.DoFnMethod method) { + return (method == null) + ? ExceptionMethod.throwing(UnsupportedOperationException.class) + : new DowncastingParametersMethodDelegation(method.targetMethod()); + } + /** * Implements a method of {@link DoFnInvoker} (the "instrumented method") by delegating to a * "target method" of the wrapped {@link DoFn}. @@ -374,12 +472,37 @@ public class DoFnInvokers { } /** + * Passes parameters to the delegated method by downcasting each parameter of non-primitive type + * to its expected type. 
+ */ + private static class DowncastingParametersMethodDelegation extends DoFnMethodDelegation { + DowncastingParametersMethodDelegation(Method method) { + super(method); + } + + @Override + protected StackManipulation beforeDelegation(MethodDescription instrumentedMethod) { + List pushParameters = new ArrayList<>(); + TypeList.Generic paramTypes = targetMethod.getParameters().asTypeList(); + for (int i = 0; i < paramTypes.size(); i++) { + TypeDescription.Generic paramT = paramTypes.get(i); + pushParameters.add(MethodVariableAccess.of(paramT).loadOffset(i + 1)); + if (!paramT.isPrimitive()) { + pushParameters.add(TypeCasting.to(paramT)); + } + } + return new StackManipulation.Compound(pushParameters); + } + } + + /** * Implements the invoker's {@link DoFnInvoker#invokeProcessElement} method by delegating to the * {@link DoFn.ProcessElement} method. */ private static final class ProcessElementDelegation extends DoFnMethodDelegation { private static final Map EXTRA_CONTEXT_FACTORY_METHODS; + private static final MethodDescription PROCESS_CONTINUATION_STOP_METHOD; static { try { @@ -397,11 +520,21 @@ public class DoFnInvokers { DoFnSignature.Parameter.OUTPUT_RECEIVER, new MethodDescription.ForLoadedMethod( DoFn.ExtraContextFactory.class.getMethod("outputReceiver"))); + methods.put( + DoFnSignature.Parameter.RESTRICTION_TRACKER, + new MethodDescription.ForLoadedMethod( + DoFn.ExtraContextFactory.class.getMethod("restrictionTracker"))); EXTRA_CONTEXT_FACTORY_METHODS = Collections.unmodifiableMap(methods); } catch (Exception e) { throw new RuntimeException( "Failed to locate an ExtraContextFactory method that was expected to exist", e); } + try { + PROCESS_CONTINUATION_STOP_METHOD = + new MethodDescription.ForLoadedMethod(DoFn.ProcessContinuation.class.getMethod("stop")); + } catch (NoSuchMethodException e) { + throw new RuntimeException("Failed to locate ProcessContinuation.stop()"); + } } private final DoFnSignature.ProcessElementMethod signature; @@ -427,14 +560,26 @@ 
public class DoFnInvokers { parameters.add( new StackManipulation.Compound( pushExtraContextFactory, - MethodInvocation.invoke(EXTRA_CONTEXT_FACTORY_METHODS.get(param)))); + MethodInvocation.invoke(EXTRA_CONTEXT_FACTORY_METHODS.get(param)), + // ExtraContextFactory.restrictionTracker() returns a RestrictionTracker, + // but the @ProcessElement method expects a concrete subtype of it. + // Insert a downcast. + (param == DoFnSignature.Parameter.RESTRICTION_TRACKER) + ? TypeCasting.to( + new TypeDescription.ForLoadedType(signature.trackerT().getRawType())) + : StackManipulation.Trivial.INSTANCE)); } return new StackManipulation.Compound(parameters); } @Override protected StackManipulation afterDelegation(MethodDescription instrumentedMethod) { - return MethodReturn.VOID; + if (TypeDescription.VOID.equals(targetMethod.getReturnType().asErasure())) { + return new StackManipulation.Compound( + MethodInvocation.invoke(PROCESS_CONTINUATION_STOP_METHOD), MethodReturn.REFERENCE); + } else { + return MethodReturn.returning(targetMethod.getReturnType().asErasure()); + } } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a0a24883/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java index 756df07..632f817 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java @@ -18,11 +18,16 @@ package org.apache.beam.sdk.transforms.reflect; import com.google.auto.value.AutoValue; +import com.google.common.reflect.TypeToken; import java.lang.reflect.Method; import java.util.Collections; import java.util.List; import javax.annotation.Nullable; +import 
org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFn.ProcessContinuation; +import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; +import org.apache.beam.sdk.values.PCollection; /** * Describes the signature of a {@link DoFn}, in particular, which features it uses, which extra @@ -35,6 +40,9 @@ public abstract class DoFnSignature { /** Class of the original {@link DoFn} from which this signature was produced. */ public abstract Class> fnClass(); + /** Whether this {@link DoFn} does a bounded amount of work per element. */ + public abstract PCollection.IsBounded isBoundedPerElement(); + /** Details about this {@link DoFn}'s {@link DoFn.ProcessElement} method. */ public abstract ProcessElementMethod processElement(); @@ -54,6 +62,22 @@ public abstract class DoFnSignature { @Nullable public abstract LifecycleMethod teardown(); + /** Details about this {@link DoFn}'s {@link DoFn.GetInitialRestriction} method. */ + @Nullable + public abstract GetInitialRestrictionMethod getInitialRestriction(); + + /** Details about this {@link DoFn}'s {@link DoFn.SplitRestriction} method. */ + @Nullable + public abstract SplitRestrictionMethod splitRestriction(); + + /** Details about this {@link DoFn}'s {@link DoFn.GetRestrictionCoder} method. */ + @Nullable + public abstract GetRestrictionCoderMethod getRestrictionCoder(); + + /** Details about this {@link DoFn}'s {@link DoFn.NewTracker} method. 
*/ + @Nullable + public abstract NewTrackerMethod newTracker(); + static Builder builder() { return new AutoValue_DoFnSignature.Builder(); } @@ -61,11 +85,16 @@ public abstract class DoFnSignature { @AutoValue.Builder abstract static class Builder { abstract Builder setFnClass(Class> fnClass); + abstract Builder setIsBoundedPerElement(PCollection.IsBounded isBounded); abstract Builder setProcessElement(ProcessElementMethod processElement); abstract Builder setStartBundle(BundleMethod startBundle); abstract Builder setFinishBundle(BundleMethod finishBundle); abstract Builder setSetup(LifecycleMethod setup); abstract Builder setTeardown(LifecycleMethod teardown); + abstract Builder setGetInitialRestriction(GetInitialRestrictionMethod getInitialRestriction); + abstract Builder setSplitRestriction(SplitRestrictionMethod splitRestriction); + abstract Builder setGetRestrictionCoder(GetRestrictionCoderMethod getRestrictionCoder); + abstract Builder setNewTracker(NewTrackerMethod newTracker); abstract DoFnSignature build(); } @@ -80,6 +109,7 @@ public abstract class DoFnSignature { BOUNDED_WINDOW, INPUT_PROVIDER, OUTPUT_RECEIVER, + RESTRICTION_TRACKER } /** Describes a {@link DoFn.ProcessElement} method. */ @@ -92,17 +122,33 @@ public abstract class DoFnSignature { /** Types of optional parameters of the annotated method, in the order they appear. */ public abstract List extraParameters(); + /** Concrete type of the {@link RestrictionTracker} parameter, if present. */ + @Nullable + abstract TypeToken trackerT(); + + /** Whether this {@link DoFn} returns a {@link ProcessContinuation} or void. 
*/ + public abstract boolean hasReturnValue(); + static ProcessElementMethod create( Method targetMethod, - List extraParameters) { + List extraParameters, + TypeToken trackerT, + boolean hasReturnValue) { return new AutoValue_DoFnSignature_ProcessElementMethod( - targetMethod, Collections.unmodifiableList(extraParameters)); + targetMethod, Collections.unmodifiableList(extraParameters), trackerT, hasReturnValue); } /** Whether this {@link DoFn} uses a Single Window. */ public boolean usesSingleWindow() { return extraParameters().contains(Parameter.BOUNDED_WINDOW); } + + /** + * Whether this {@link DoFn} is splittable. + */ + public boolean isSplittable() { + return extraParameters().contains(Parameter.RESTRICTION_TRACKER); + } } /** Describes a {@link DoFn.StartBundle} or {@link DoFn.FinishBundle} method. */ @@ -128,4 +174,68 @@ public abstract class DoFnSignature { return new AutoValue_DoFnSignature_LifecycleMethod(targetMethod); } } + + /** Describes a {@link DoFn.GetInitialRestriction} method. */ + @AutoValue + public abstract static class GetInitialRestrictionMethod implements DoFnMethod { + /** The annotated method itself. */ + @Override + public abstract Method targetMethod(); + + /** Type of the returned restriction. */ + abstract TypeToken restrictionT(); + + static GetInitialRestrictionMethod create(Method targetMethod, TypeToken restrictionT) { + return new AutoValue_DoFnSignature_GetInitialRestrictionMethod(targetMethod, restrictionT); + } + } + + /** Describes a {@link DoFn.SplitRestriction} method. */ + @AutoValue + public abstract static class SplitRestrictionMethod implements DoFnMethod { + /** The annotated method itself. */ + @Override + public abstract Method targetMethod(); + + /** Type of the restriction taken and returned. 
*/ + abstract TypeToken restrictionT(); + + static SplitRestrictionMethod create(Method targetMethod, TypeToken restrictionT) { + return new AutoValue_DoFnSignature_SplitRestrictionMethod(targetMethod, restrictionT); + } + } + + /** Describes a {@link DoFn.NewTracker} method. */ + @AutoValue + public abstract static class NewTrackerMethod implements DoFnMethod { + /** The annotated method itself. */ + @Override + public abstract Method targetMethod(); + + /** Type of the input restriction. */ + abstract TypeToken restrictionT(); + + /** Type of the returned {@link RestrictionTracker}. */ + abstract TypeToken trackerT(); + + static NewTrackerMethod create( + Method targetMethod, TypeToken restrictionT, TypeToken trackerT) { + return new AutoValue_DoFnSignature_NewTrackerMethod(targetMethod, restrictionT, trackerT); + } + } + + /** Describes a {@link DoFn.GetRestrictionCoder} method. */ + @AutoValue + public abstract static class GetRestrictionCoderMethod implements DoFnMethod { + /** The annotated method itself. */ + @Override + public abstract Method targetMethod(); + + /** Type of the returned {@link Coder}. 
*/ + abstract TypeToken coderT(); + + static GetRestrictionCoderMethod create(Method targetMethod, TypeToken coderT) { + return new AutoValue_DoFnSignature_GetRestrictionCoderMethod(targetMethod, coderT); + } + } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a0a24883/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java index ad15127..524ea24 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java @@ -17,6 +17,8 @@ */ package org.apache.beam.sdk.transforms.reflect; +import static com.google.common.base.Preconditions.checkState; + import com.google.common.annotations.VisibleForTesting; import com.google.common.reflect.TypeParameter; import com.google.common.reflect.TypeToken; @@ -34,9 +36,12 @@ import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import javax.annotation.Nullable; +import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.common.ReflectHelpers; +import org.apache.beam.sdk.values.PCollection; /** * Parses a {@link DoFn} and computes its {@link DoFnSignature}. See {@link #getOrParseSignature}. 
@@ -88,6 +93,14 @@ public class DoFnSignatures { Method setupMethod = findAnnotatedMethod(errors, DoFn.Setup.class, fnClass, false); Method teardownMethod = findAnnotatedMethod(errors, DoFn.Teardown.class, fnClass, false); + Method getInitialRestrictionMethod = + findAnnotatedMethod(errors, DoFn.GetInitialRestriction.class, fnClass, false); + Method splitRestrictionMethod = + findAnnotatedMethod(errors, DoFn.SplitRestriction.class, fnClass, false); + Method getRestrictionCoderMethod = + findAnnotatedMethod(errors, DoFn.GetRestrictionCoder.class, fnClass, false); + Method newTrackerMethod = findAnnotatedMethod(errors, DoFn.NewTracker.class, fnClass, false); + ErrorReporter processElementErrors = errors.forMethod(DoFn.ProcessElement.class, processElementMethod); DoFnSignature.ProcessElementMethod processElement = @@ -119,7 +132,213 @@ public class DoFnSignatures { errors.forMethod(DoFn.Teardown.class, teardownMethod), teardownMethod)); } - return builder.build(); + DoFnSignature.GetInitialRestrictionMethod getInitialRestriction = null; + ErrorReporter getInitialRestrictionErrors = null; + if (getInitialRestrictionMethod != null) { + getInitialRestrictionErrors = + errors.forMethod(DoFn.GetInitialRestriction.class, getInitialRestrictionMethod); + builder.setGetInitialRestriction( + getInitialRestriction = + analyzeGetInitialRestrictionMethod( + getInitialRestrictionErrors, fnToken, getInitialRestrictionMethod, inputT)); + } + + DoFnSignature.SplitRestrictionMethod splitRestriction = null; + if (splitRestrictionMethod != null) { + ErrorReporter splitRestrictionErrors = + errors.forMethod(DoFn.SplitRestriction.class, splitRestrictionMethod); + builder.setSplitRestriction( + splitRestriction = + analyzeSplitRestrictionMethod( + splitRestrictionErrors, fnToken, splitRestrictionMethod, inputT)); + } + + DoFnSignature.GetRestrictionCoderMethod getRestrictionCoder = null; + if (getRestrictionCoderMethod != null) { + ErrorReporter getRestrictionCoderErrors = + 
errors.forMethod(DoFn.GetRestrictionCoder.class, getRestrictionCoderMethod); + builder.setGetRestrictionCoder( + getRestrictionCoder = + analyzeGetRestrictionCoderMethod( + getRestrictionCoderErrors, fnToken, getRestrictionCoderMethod)); + } + + DoFnSignature.NewTrackerMethod newTracker = null; + if (newTrackerMethod != null) { + ErrorReporter newTrackerErrors = errors.forMethod(DoFn.NewTracker.class, newTrackerMethod); + builder.setNewTracker( + newTracker = analyzeNewTrackerMethod(newTrackerErrors, fnToken, newTrackerMethod)); + } + + builder.setIsBoundedPerElement(inferBoundedness(fnToken, processElement, errors)); + + DoFnSignature signature = builder.build(); + + // Additional validation for splittable DoFn's. + if (processElement.isSplittable()) { + verifySplittableMethods(signature, errors); + } else { + verifyUnsplittableMethods(errors, signature); + } + + return signature; + } + + /** + * Infers the boundedness of the {@link DoFn.ProcessElement} method (whether or not it performs a + * bounded amount of work per element) using the following criteria: + * + *
    + *
  1. If the {@link DoFn} is not splittable, then it is bounded, it must not be annotated as + * {@link DoFn.BoundedPerElement} or {@link DoFn.UnboundedPerElement}, and {@link + * DoFn.ProcessElement} must return {@code void}. + *
  2. If the {@link DoFn} (or any of its supertypes) is annotated as {@link + * DoFn.BoundedPerElement} or {@link DoFn.UnboundedPerElement}, use that. At most one of + * these may be specified.
  3. If {@link DoFn.ProcessElement} returns {@link DoFn.ProcessContinuation}, assume it is + * unbounded. Otherwise (if it returns {@code void}), assume it is bounded. + *
  4. If {@link DoFn.ProcessElement} returns {@code void}, but the {@link DoFn} is annotated + * {@link DoFn.UnboundedPerElement}, this is an error. + *
+ */ + private static PCollection.IsBounded inferBoundedness( + TypeToken fnToken, + DoFnSignature.ProcessElementMethod processElement, + ErrorReporter errors) { + PCollection.IsBounded isBounded = null; + for (TypeToken supertype : fnToken.getTypes()) { + if (supertype.getRawType().isAnnotationPresent(DoFn.BoundedPerElement.class) + || supertype.getRawType().isAnnotationPresent(DoFn.UnboundedPerElement.class)) { + errors.checkArgument( + isBounded == null, + "Both @%s and @%s specified", + DoFn.BoundedPerElement.class.getSimpleName(), + DoFn.UnboundedPerElement.class.getSimpleName()); + isBounded = + supertype.getRawType().isAnnotationPresent(DoFn.BoundedPerElement.class) + ? PCollection.IsBounded.BOUNDED + : PCollection.IsBounded.UNBOUNDED; + } + } + if (processElement.isSplittable()) { + if (isBounded == null) { + isBounded = + processElement.hasReturnValue() + ? PCollection.IsBounded.UNBOUNDED + : PCollection.IsBounded.BOUNDED; + } + } else { + errors.checkArgument( + isBounded == null, + "Non-splittable, but annotated as @" + + ((isBounded == PCollection.IsBounded.BOUNDED) + ? DoFn.BoundedPerElement.class.getSimpleName() + : DoFn.UnboundedPerElement.class.getSimpleName())); + checkState(!processElement.hasReturnValue(), "Should have been inferred splittable"); + isBounded = PCollection.IsBounded.BOUNDED; + } + return isBounded; + } + + /** + * Verifies properties related to methods of splittable {@link DoFn}: + * + *
    + *
  • Must declare the required {@link DoFn.GetInitialRestriction} and {@link DoFn.NewTracker} + * methods. + *
  • Types of restrictions and trackers must match exactly between {@link DoFn.ProcessElement}, + * {@link DoFn.GetInitialRestriction}, {@link DoFn.NewTracker}, {@link + * DoFn.GetRestrictionCoder}, {@link DoFn.SplitRestriction}. + *
+ */ + private static void verifySplittableMethods(DoFnSignature signature, ErrorReporter errors) { + DoFnSignature.ProcessElementMethod processElement = signature.processElement(); + DoFnSignature.GetInitialRestrictionMethod getInitialRestriction = + signature.getInitialRestriction(); + DoFnSignature.NewTrackerMethod newTracker = signature.newTracker(); + DoFnSignature.GetRestrictionCoderMethod getRestrictionCoder = signature.getRestrictionCoder(); + DoFnSignature.SplitRestrictionMethod splitRestriction = signature.splitRestriction(); + + ErrorReporter processElementErrors = + errors.forMethod(DoFn.ProcessElement.class, processElement.targetMethod()); + + List missingRequiredMethods = new ArrayList<>(); + if (getInitialRestriction == null) { + missingRequiredMethods.add("@" + DoFn.GetInitialRestriction.class.getSimpleName()); + } + if (newTracker == null) { + missingRequiredMethods.add("@" + DoFn.NewTracker.class.getSimpleName()); + } + if (!missingRequiredMethods.isEmpty()) { + processElementErrors.throwIllegalArgument( + "Splittable, but does not define the following required methods: %s", + missingRequiredMethods); + } + + processElementErrors.checkArgument( + processElement.trackerT().equals(newTracker.trackerT()), + "Has tracker type %s, but @%s method %s uses tracker type %s", + formatType(processElement.trackerT()), + DoFn.NewTracker.class.getSimpleName(), + format(newTracker.targetMethod()), + formatType(newTracker.trackerT())); + + ErrorReporter getInitialRestrictionErrors = + errors.forMethod(DoFn.GetInitialRestriction.class, getInitialRestriction.targetMethod()); + TypeToken restrictionT = getInitialRestriction.restrictionT(); + + getInitialRestrictionErrors.checkArgument( + restrictionT.equals(newTracker.restrictionT()), + "Uses restriction type %s, but @%s method %s uses restriction type %s", + formatType(restrictionT), + DoFn.NewTracker.class.getSimpleName(), + format(newTracker.targetMethod()), + formatType(newTracker.restrictionT())); + + if 
(getRestrictionCoder != null) { + getInitialRestrictionErrors.checkArgument( + getRestrictionCoder.coderT().isSubtypeOf(coderTypeOf(restrictionT)), + "Uses restriction type %s, but @%s method %s returns %s " + + "which is not a subtype of %s", + formatType(restrictionT), + DoFn.GetRestrictionCoder.class.getSimpleName(), + format(getRestrictionCoder.targetMethod()), + formatType(getRestrictionCoder.coderT()), + formatType(coderTypeOf(restrictionT))); + } + + if (splitRestriction != null) { + getInitialRestrictionErrors.checkArgument( + splitRestriction.restrictionT().equals(restrictionT), + "Uses restriction type %s, but @%s method %s uses restriction type %s", + formatType(restrictionT), + DoFn.SplitRestriction.class.getSimpleName(), + format(splitRestriction.targetMethod()), + formatType(splitRestriction.restrictionT())); + } + } + + /** + * Verifies that a non-splittable {@link DoFn} does not declare any methods that only make sense + * for splittable {@link DoFn}: {@link DoFn.GetInitialRestriction}, {@link DoFn.SplitRestriction}, + * {@link DoFn.NewTracker}, {@link DoFn.GetRestrictionCoder}. 
+ */ + private static void verifyUnsplittableMethods(ErrorReporter errors, DoFnSignature signature) { + List forbiddenMethods = new ArrayList<>(); + if (signature.getInitialRestriction() != null) { + forbiddenMethods.add("@" + DoFn.GetInitialRestriction.class.getSimpleName()); + } + if (signature.splitRestriction() != null) { + forbiddenMethods.add("@" + DoFn.SplitRestriction.class.getSimpleName()); + } + if (signature.newTracker() != null) { + forbiddenMethods.add("@" + DoFn.NewTracker.class.getSimpleName()); + } + if (signature.getRestrictionCoder() != null) { + forbiddenMethods.add("@" + DoFn.GetRestrictionCoder.class.getSimpleName()); + } + errors.checkArgument( + forbiddenMethods.isEmpty(), "Non-splittable, but defines methods: %s", forbiddenMethods); } /** @@ -166,7 +385,11 @@ public class DoFnSignatures { Method m, TypeToken inputT, TypeToken outputT) { - errors.checkArgument(void.class.equals(m.getReturnType()), "Must return void"); + errors.checkArgument( + void.class.equals(m.getReturnType()) + || DoFn.ProcessContinuation.class.equals(m.getReturnType()), + "Must return void or %s", + DoFn.ProcessContinuation.class.getSimpleName()); TypeToken processContextToken = doFnProcessContextTypeOf(inputT, outputT); @@ -181,6 +404,7 @@ public class DoFnSignatures { formatType(processContextToken)); List extraParameters = new ArrayList<>(); + TypeToken trackerT = null; TypeToken expectedInputProviderT = inputProviderTypeOf(inputT); TypeToken expectedOutputReceiverT = outputReceiverTypeOf(outputT); @@ -190,38 +414,62 @@ public class DoFnSignatures { if (rawType.equals(BoundedWindow.class)) { errors.checkArgument( !extraParameters.contains(DoFnSignature.Parameter.BOUNDED_WINDOW), - "Multiple BoundedWindow parameters"); + "Multiple %s parameters", + BoundedWindow.class.getSimpleName()); extraParameters.add(DoFnSignature.Parameter.BOUNDED_WINDOW); } else if (rawType.equals(DoFn.InputProvider.class)) { errors.checkArgument( 
!extraParameters.contains(DoFnSignature.Parameter.INPUT_PROVIDER), - "Multiple InputProvider parameters"); + "Multiple %s parameters", + DoFn.InputProvider.class.getSimpleName()); errors.checkArgument( paramT.equals(expectedInputProviderT), - "Wrong type of InputProvider parameter: %s, should be %s", + "Wrong type of %s parameter: %s, should be %s", + DoFn.InputProvider.class.getSimpleName(), formatType(paramT), formatType(expectedInputProviderT)); extraParameters.add(DoFnSignature.Parameter.INPUT_PROVIDER); } else if (rawType.equals(DoFn.OutputReceiver.class)) { errors.checkArgument( !extraParameters.contains(DoFnSignature.Parameter.OUTPUT_RECEIVER), - "Multiple OutputReceiver parameters"); + "Multiple %s parameters", + DoFn.OutputReceiver.class.getSimpleName()); errors.checkArgument( paramT.equals(expectedOutputReceiverT), - "Wrong type of OutputReceiver parameter: %s, should be %s", + "Wrong type of %s parameter: %s, should be %s", + DoFn.OutputReceiver.class.getSimpleName(), formatType(paramT), formatType(expectedOutputReceiverT)); extraParameters.add(DoFnSignature.Parameter.OUTPUT_RECEIVER); + } else if (RestrictionTracker.class.isAssignableFrom(rawType)) { + errors.checkArgument( + !extraParameters.contains(DoFnSignature.Parameter.RESTRICTION_TRACKER), + "Multiple %s parameters", + RestrictionTracker.class.getSimpleName()); + extraParameters.add(DoFnSignature.Parameter.RESTRICTION_TRACKER); + trackerT = paramT; } else { List allowedParamTypes = - Arrays.asList(formatType(new TypeToken() {})); + Arrays.asList( + formatType(new TypeToken() {}), + formatType(new TypeToken>() {})); errors.throwIllegalArgument( "%s is not a valid context parameter. Should be one of %s", formatType(paramT), allowedParamTypes); } } - return DoFnSignature.ProcessElementMethod.create(m, extraParameters); + // A splittable DoFn can not have any other extra context parameters. 
+ if (extraParameters.contains(DoFnSignature.Parameter.RESTRICTION_TRACKER)) { + errors.checkArgument( + extraParameters.size() == 1, + "Splittable DoFn must not have any extra context arguments apart from %s, but has: %s", + trackerT, + extraParameters); + } + + return DoFnSignature.ProcessElementMethod.create( + m, extraParameters, trackerT, DoFn.ProcessContinuation.class.equals(m.getReturnType())); } @VisibleForTesting @@ -248,6 +496,100 @@ public class DoFnSignatures { return DoFnSignature.LifecycleMethod.create(m); } + @VisibleForTesting + static DoFnSignature.GetInitialRestrictionMethod analyzeGetInitialRestrictionMethod( + ErrorReporter errors, TypeToken fnToken, Method m, TypeToken inputT) { + // Method is of the form: + // @GetInitialRestriction + // RestrictionT getInitialRestriction(InputT element); + Type[] params = m.getGenericParameterTypes(); + errors.checkArgument( + params.length == 1 && fnToken.resolveType(params[0]).equals(inputT), + "Must take a single argument of type %s", + formatType(inputT)); + return DoFnSignature.GetInitialRestrictionMethod.create( + m, fnToken.resolveType(m.getGenericReturnType())); + } + + /** Generates a type token for {@code List} given {@code T}. 
*/ + private static TypeToken> listTypeOf(TypeToken elementT) { + return new TypeToken>() {}.where(new TypeParameter() {}, elementT); + } + + @VisibleForTesting + static DoFnSignature.SplitRestrictionMethod analyzeSplitRestrictionMethod( + ErrorReporter errors, TypeToken fnToken, Method m, TypeToken inputT) { + // Method is of the form: + // @SplitRestriction + // void splitRestriction(InputT element, RestrictionT restriction); + errors.checkArgument(void.class.equals(m.getReturnType()), "Must return void"); + + Type[] params = m.getGenericParameterTypes(); + errors.checkArgument(params.length == 3, "Must have exactly 3 arguments"); + errors.checkArgument( + fnToken.resolveType(params[0]).equals(inputT), + "First argument must be the element type %s", + formatType(inputT)); + + TypeToken restrictionT = fnToken.resolveType(params[1]); + TypeToken receiverT = fnToken.resolveType(params[2]); + TypeToken expectedReceiverT = outputReceiverTypeOf(restrictionT); + errors.checkArgument( + receiverT.equals(expectedReceiverT), + "Third argument must be %s, but is %s", + formatType(expectedReceiverT), + formatType(receiverT)); + + return DoFnSignature.SplitRestrictionMethod.create(m, restrictionT); + } + + /** Generates a type token for {@code Coder} given {@code T}. 
*/ + private static TypeToken> coderTypeOf(TypeToken elementT) { + return new TypeToken>() {}.where(new TypeParameter() {}, elementT); + } + + @VisibleForTesting + static DoFnSignature.GetRestrictionCoderMethod analyzeGetRestrictionCoderMethod( + ErrorReporter errors, TypeToken fnToken, Method m) { + errors.checkArgument(m.getParameterTypes().length == 0, "Must have zero arguments"); + TypeToken resT = fnToken.resolveType(m.getGenericReturnType()); + errors.checkArgument( + resT.isSubtypeOf(TypeToken.of(Coder.class)), + "Must return a Coder, but returns %s", + formatType(resT)); + return DoFnSignature.GetRestrictionCoderMethod.create(m, resT); + } + + /** + * Generates a type token for {@code RestrictionTracker} given {@code RestrictionT}. + */ + private static + TypeToken> restrictionTrackerTypeOf( + TypeToken restrictionT) { + return new TypeToken>() {}.where( + new TypeParameter() {}, restrictionT); + } + + @VisibleForTesting + static DoFnSignature.NewTrackerMethod analyzeNewTrackerMethod( + ErrorReporter errors, TypeToken fnToken, Method m) { + // Method is of the form: + // @NewTracker + // TrackerT newTracker(RestrictionT restriction); + Type[] params = m.getGenericParameterTypes(); + errors.checkArgument(params.length == 1, "Must have a single argument"); + + TypeToken restrictionT = fnToken.resolveType(params[0]); + TypeToken trackerT = fnToken.resolveType(m.getGenericReturnType()); + TypeToken expectedTrackerT = restrictionTrackerTypeOf(restrictionT); + errors.checkArgument( + trackerT.isSubtypeOf(expectedTrackerT), + "Returns %s, but must return a subtype of %s", + formatType(trackerT), + formatType(expectedTrackerT)); + return DoFnSignature.NewTrackerMethod.create(m, restrictionT, trackerT); + } + private static Collection declaredMethodsWithAnnotation( Class anno, Class startClass, Class stopClass) { Collection matches = new ArrayList<>(); @@ -310,7 +652,7 @@ public class DoFnSignatures { } private static String format(Method method) { - return 
ReflectHelpers.CLASS_AND_METHOD_FORMATTER.apply(method); + return ReflectHelpers.METHOD_FORMATTER.apply(method); } private static String formatType(TypeToken t) { @@ -327,7 +669,9 @@ public class DoFnSignatures { ErrorReporter forMethod(Class annotation, Method method) { return new ErrorReporter( this, - String.format("@%s %s", annotation, (method == null) ? "(absent)" : format(method))); + String.format( + "@%s %s", + annotation.getSimpleName(), (method == null) ? "(absent)" : format(method))); } void throwIllegalArgument(String message, Object... args) { http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a0a24883/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/RestrictionTracker.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/RestrictionTracker.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/RestrictionTracker.java new file mode 100644 index 0000000..6b249ee --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/RestrictionTracker.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.transforms.splittabledofn; + +import org.apache.beam.sdk.transforms.DoFn; + +/** + * Manages concurrent access to the restriction and keeps track of its claimed part for a AnyCoderAndData coderAndData(Coder coder, List data) { + CoderAndData coderAndData = new CoderAndData<>(); + coderAndData.coder = coder; + coderAndData.data = data; + AnyCoderAndData res = new AnyCoderAndData(); + res.coderAndData = coderAndData; + return res; + } + + private static final List TEST_DATA = + Arrays.asList( + coderAndData( + VarIntCoder.of(), Arrays.asList(-1, 0, 1, 13, Integer.MAX_VALUE, Integer.MIN_VALUE)), + coderAndData( + BigEndianLongCoder.of(), + Arrays.asList(-1L, 0L, 1L, 13L, Long.MAX_VALUE, Long.MIN_VALUE)), + coderAndData(StringUtf8Coder.of(), Arrays.asList("", "hello", "goodbye", "1")), + coderAndData( + KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()), + Arrays.asList(KV.of("", -1), KV.of("hello", 0), KV.of("goodbye", Integer.MAX_VALUE))), + coderAndData( + ListCoder.of(VarLongCoder.of()), + Arrays.asList(Arrays.asList(1L, 2L, 3L), Collections.emptyList()))); @Test + @SuppressWarnings("rawtypes") public void testDecodeEncodeEqual() throws Exception { - for (Map.Entry, Iterable> entry : TEST_DATA.entrySet()) { - // The coder and corresponding values must be the same type. - // If someone messes this up in the above test data, the test - // will fail anyhow (unless the coder magically works on data - // it does not understand). 
- @SuppressWarnings("unchecked") - Coder coder = (Coder) entry.getKey(); - Iterable values = entry.getValue(); - for (Object value : values) { - CoderProperties.coderDecodeEncodeEqual(coder, value); + for (AnyCoderAndData keyCoderAndData : TEST_DATA) { + Coder keyCoder = keyCoderAndData.coderAndData.coder; + for (Object key : keyCoderAndData.coderAndData.data) { + for (AnyCoderAndData valueCoderAndData : TEST_DATA) { + Coder valueCoder = valueCoderAndData.coderAndData.coder; + for (Object value : valueCoderAndData.coderAndData.data) { + CoderProperties.coderDecodeEncodeEqual( + KvCoder.of(keyCoder, valueCoder), KV.of(key, value)); + } + } } } } @@ -75,37 +88,29 @@ public class KvCoderTest { @Test public void testEncodingId() throws Exception { CoderProperties.coderHasEncodingId( - KvCoder.of(VarIntCoder.of(), VarIntCoder.of()), - EXPECTED_ENCODING_ID); + KvCoder.of(VarIntCoder.of(), VarIntCoder.of()), EXPECTED_ENCODING_ID); } - /** - * Homogeneously typed test value for ease of use with the wire format test utility. - */ + /** Homogeneously typed test value for ease of use with the wire format test utility. */ private static final Coder> TEST_CODER = KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()); - private static final List> TEST_VALUES = Arrays.asList( - KV.of("", -1), - KV.of("hello", 0), - KV.of("goodbye", Integer.MAX_VALUE)); + private static final List> TEST_VALUES = + Arrays.asList(KV.of("", -1), KV.of("hello", 0), KV.of("goodbye", Integer.MAX_VALUE)); /** - * Generated data to check that the wire format has not changed. To regenerate, see - * {@link org.apache.beam.sdk.coders.PrintBase64Encodings}. + * Generated data to check that the wire format has not changed. To regenerate, see {@link + * org.apache.beam.sdk.coders.PrintBase64Encodings}. 
*/ - private static final List TEST_ENCODINGS = Arrays.asList( - "AP____8P", - "BWhlbGxvAA", - "B2dvb2RieWX_____Bw"); + private static final List TEST_ENCODINGS = + Arrays.asList("AP____8P", "BWhlbGxvAA", "B2dvb2RieWX_____Bw"); @Test public void testWireFormatEncode() throws Exception { CoderProperties.coderEncodesBase64(TEST_CODER, TEST_VALUES, TEST_ENCODINGS); } - @Rule - public ExpectedException thrown = ExpectedException.none(); + @Rule public ExpectedException thrown = ExpectedException.none(); @Test public void encodeNullThrowsCoderException() throws Exception { http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a0a24883/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java index 7ce98bc..9c7b991 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java @@ -56,6 +56,7 @@ import org.apache.beam.sdk.transforms.ParDo.Bound; import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.transforms.display.DisplayData.Builder; import org.apache.beam.sdk.transforms.display.DisplayDataMatchers; +import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; import org.apache.beam.sdk.transforms.windowing.FixedWindows; import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.util.common.ElementByteSizeObserver; @@ -1469,4 +1470,49 @@ public class ParDoTest implements Serializable { assertThat(displayData, includesDisplayDataFrom(fn)); assertThat(displayData, hasDisplayItem("fn", fn.getClass())); } + + private abstract static class SomeTracker implements RestrictionTracker {} + private static class TestSplittableDoFn extends DoFn { + 
@ProcessElement + public void processElement(ProcessContext context, SomeTracker tracker) {} + + @GetInitialRestriction + public Object getInitialRestriction(Integer element) { + return null; + } + + @NewTracker + public SomeTracker newTracker(Object restriction) { + return null; + } + } + + @Test + public void testRejectsSplittableDoFnByDefault() { + // ParDo with a splittable DoFn must be overridden by the runner. + // Without an override, applying it directly must fail. + Pipeline p = TestPipeline.create(); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("Splittable DoFn not supported by the current runner"); + + p.apply(Create.of(1, 2, 3)).apply(ParDo.of(new TestSplittableDoFn())); + } + + @Test + public void testMultiRejectsSplittableDoFnByDefault() { + // ParDo with a splittable DoFn must be overridden by the runner. + // Without an override, applying it directly must fail. + Pipeline p = TestPipeline.create(); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("Splittable DoFn not supported by the current runner"); + + p.apply(Create.of(1, 2, 3)) + .apply( + ParDo.of(new TestSplittableDoFn()) + .withOutputTags( + new TupleTag("main") {}, + TupleTagList.of(new TupleTag("side1") {}))); + } }