From commits-return-114551-archive-asf-public=cust-asf.ponee.io@beam.apache.org Fri Aug 6 17:04:41 2021 Return-Path: X-Original-To: archive-asf-public@cust-asf.ponee.io Delivered-To: archive-asf-public@cust-asf.ponee.io Received: from mxout1-ec2-va.apache.org (mxout1-ec2-va.apache.org [3.227.148.255]) by mx-eu-01.ponee.io (Postfix) with ESMTPS id 86B77180181 for ; Fri, 6 Aug 2021 19:04:41 +0200 (CEST) Received: from mail.apache.org (mailroute1-lw-us.apache.org [207.244.88.153]) by mxout1-ec2-va.apache.org (ASF Mail Server at mxout1-ec2-va.apache.org) with SMTP id C61E944947 for ; Fri, 6 Aug 2021 17:04:40 +0000 (UTC) Received: (qmail 46112 invoked by uid 500); 6 Aug 2021 17:04:40 -0000 Mailing-List: contact commits-help@beam.apache.org; run by ezmlm Precedence: bulk List-Help: List-Unsubscribe: List-Post: List-Id: Reply-To: dev@beam.apache.org Delivered-To: mailing list commits@beam.apache.org Received: (qmail 46103 invoked by uid 99); 6 Aug 2021 17:04:40 -0000 Received: from ec2-52-202-80-70.compute-1.amazonaws.com (HELO gitbox.apache.org) (52.202.80.70) by apache.org (qpsmtpd/0.29) with ESMTP; Fri, 06 Aug 2021 17:04:40 +0000 Received: by gitbox.apache.org (ASF Mail Server at gitbox.apache.org, from userid 33) id 39B6B81F23; Fri, 6 Aug 2021 17:04:40 +0000 (UTC) Date: Fri, 06 Aug 2021 17:04:35 +0000 To: "commits@beam.apache.org" Subject: [beam] branch master updated: [BEAM-10212] Integrate caching client (#15214) MIME-Version: 1.0 Content-Type: text/plain; charset=utf-8 Content-Transfer-Encoding: 8bit Message-ID: <162826947002.28800.7117979410923206995@gitbox.apache.org> From: amaliujia@apache.org X-Git-Host: gitbox.apache.org X-Git-Repo: beam X-Git-Refname: refs/heads/master X-Git-Reftype: branch X-Git-Oldrev: 3537f7ed430de9a85dd76b7a7db51fc67024db12 X-Git-Newrev: 9c9903d50b59a6ca956b9d43809dc26c490cb849 X-Git-Rev: 9c9903d50b59a6ca956b9d43809dc26c490cb849 X-Git-NotificationType: ref_changed_plus_diff X-Git-Multimail-Version: 1.5.dev Auto-Submitted: auto-generated This is an automated email from the ASF dual-hosted git repository. amaliujia pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/beam.git The following commit(s) were added to refs/heads/master by this push: new 9c9903d [BEAM-10212] Integrate caching client (#15214) 9c9903d is described below commit 9c9903d50b59a6ca956b9d43809dc26c490cb849 Author: anthonyqzhu <43458232+anthonyqzhu@users.noreply.github.com> AuthorDate: Fri Aug 6 13:03:28 2021 -0400 [BEAM-10212] Integrate caching client (#15214) * [BEAM-10212] Add state cache to ProcessBundleHandler --- .../fnexecution/control/RemoteExecutionTest.java | 407 +++++++++++++++++++++ .../fn/harness/control/ProcessBundleHandler.java | 52 ++- 2 files changed, 451 insertions(+), 8 deletions(-) diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java index 5acb87d..f238601 100644 --- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java +++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java @@ -42,6 +42,7 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Set; import java.util.UUID; +import java.util.concurrent.CompletionStage; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.CountDownLatch; @@ -55,6 +56,7 @@ import java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeUnit; import java.util.function.Function; import org.apache.beam.fn.harness.FnHarness; +import org.apache.beam.model.fnexecution.v1.BeamFnApi; import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleProgressResponse; import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse; import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleSplitResponse; @@ -68,6 +70,7 @@ import org.apache.beam.runners.core.construction.graph.FusedPipeline; import org.apache.beam.runners.core.construction.graph.GreedyPipelineFuser; import org.apache.beam.runners.core.construction.graph.PipelineNode.PTransformNode; import org.apache.beam.runners.core.construction.graph.ProtoOverrides; +import org.apache.beam.runners.core.construction.graph.SideInputReference; import org.apache.beam.runners.core.construction.graph.SplittableParDoExpander; import org.apache.beam.runners.core.metrics.DistributionData; import org.apache.beam.runners.core.metrics.ExecutionStateSampler; @@ -532,6 +535,160 @@ public class RemoteExecutionTest implements Serializable { } } + @Test + public void testExecutionWithSideInputCaching() throws Exception { + Pipeline p = Pipeline.create(); + addExperiment(p.getOptions().as(ExperimentalOptions.class), "beam_fn_api"); + // TODO(BEAM-10097): Remove experiment once all portable runners support this view type + addExperiment(p.getOptions().as(ExperimentalOptions.class), "use_runner_v2"); + PCollection input = + p.apply("impulse", Impulse.create()) + .apply( + "create", + ParDo.of( + new DoFn() { + @ProcessElement + public void process(ProcessContext ctxt) { + ctxt.output("zero"); + ctxt.output("one"); + ctxt.output("two"); + } + })) + .setCoder(StringUtf8Coder.of()); + PCollectionView> view = input.apply("createSideInput", View.asIterable()); + + input + .apply( + "readSideInput", + ParDo.of( + new DoFn>() { + @ProcessElement + public void processElement(ProcessContext context) { + for (String value : context.sideInput(view)) { + context.output(KV.of(context.element(), value)); + } + } + }) + .withSideInputs(view)) + .setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of())) + // Force the output to be materialized + .apply("gbk", GroupByKey.create()); + + RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p); + FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineProto); + Optional optionalStage = + Iterables.tryFind( + fused.getFusedStages(), (ExecutableStage stage) -> !stage.getSideInputs().isEmpty()); + checkState(optionalStage.isPresent(), "Expected a stage with side inputs."); + ExecutableStage stage = optionalStage.get(); + + ExecutableProcessBundleDescriptor descriptor = + ProcessBundleDescriptors.fromExecutableStage( + "test_stage", + stage, + dataServer.getApiServiceDescriptor(), + stateServer.getApiServiceDescriptor()); + + BundleProcessor processor = + controlClient.getProcessor( + descriptor.getProcessBundleDescriptor(), + descriptor.getRemoteInputDestinations(), + stateDelegator); + Map remoteOutputCoders = descriptor.getRemoteOutputCoders(); + Map>> outputValues = new HashMap<>(); + Map> outputReceivers = new HashMap<>(); + for (Entry remoteOutputCoder : remoteOutputCoders.entrySet()) { + List> outputContents = Collections.synchronizedList(new ArrayList<>()); + outputValues.put(remoteOutputCoder.getKey(), outputContents); + outputReceivers.put( + remoteOutputCoder.getKey(), + RemoteOutputReceiver.of( + (Coder>) remoteOutputCoder.getValue(), outputContents::add)); + } + + StoringStateRequestHandler stateRequestHandler = + new StoringStateRequestHandler( + StateRequestHandlers.forSideInputHandlerFactory( + descriptor.getSideInputSpecs(), + new SideInputHandlerFactory() { + @Override + public + IterableSideInputHandler forIterableSideInput( + String pTransformId, + String sideInputId, + Coder elementCoder, + Coder windowCoder) { + return new IterableSideInputHandler() { + @Override + public Iterable get(W window) { + return (Iterable) Arrays.asList("A", "B", "C"); + } + + @Override + public Coder elementCoder() { + return elementCoder; + } + }; + } + + @Override + public + MultimapSideInputHandler forMultimapSideInput( + String pTransformId, + String sideInputId, + KvCoder elementCoder, + Coder windowCoder) { + throw new UnsupportedOperationException(); + } + })); + SideInputReference sideInputReference = stage.getSideInputs().iterator().next(); + String transformId = sideInputReference.transform().getId(); + String sideInputId = sideInputReference.localName(); + stateRequestHandler.addCacheToken( + BeamFnApi.ProcessBundleRequest.CacheToken.newBuilder() + .setSideInput( + BeamFnApi.ProcessBundleRequest.CacheToken.SideInput.newBuilder() + .setSideInputId(sideInputId) + .setTransformId(transformId) + .build()) + .setToken(ByteString.copyFromUtf8("SideInputToken")) + .build()); + BundleProgressHandler progressHandler = BundleProgressHandler.ignored(); + + try (RemoteBundle bundle = + processor.newBundle(outputReceivers, stateRequestHandler, progressHandler)) { + Iterables.getOnlyElement(bundle.getInputReceivers().values()) + .accept(valueInGlobalWindow("X")); + } + + try (RemoteBundle bundle = + processor.newBundle(outputReceivers, stateRequestHandler, progressHandler)) { + Iterables.getOnlyElement(bundle.getInputReceivers().values()) + .accept(valueInGlobalWindow("X")); + } + for (Collection> windowedValues : outputValues.values()) { + assertThat( + windowedValues, + containsInAnyOrder( + valueInGlobalWindow(KV.of("X", "A")), + valueInGlobalWindow(KV.of("X", "B")), + valueInGlobalWindow(KV.of("X", "C")), + valueInGlobalWindow(KV.of("X", "A")), + valueInGlobalWindow(KV.of("X", "B")), + valueInGlobalWindow(KV.of("X", "C")))); + } + + // Only expect one read to the sideInput + assertEquals(1, stateRequestHandler.receivedRequests.size()); + BeamFnApi.StateRequest receivedRequest = stateRequestHandler.receivedRequests.get(0); + assertEquals( + receivedRequest.getStateKey().getIterableSideInput(), + BeamFnApi.StateKey.IterableSideInput.newBuilder() + .setSideInputId(sideInputId) + .setTransformId(transformId) + .build()); + } + /** * A {@link DoFn} that uses static maps of {@link CountDownLatch}es to block execution allowing * for synchronization during test execution. The expected flow is: @@ -1041,6 +1198,256 @@ public class RemoteExecutionTest implements Serializable { } @Test + public void testExecutionWithUserStateCaching() throws Exception { + Pipeline p = Pipeline.create(); + final String stateId = "foo"; + final String stateId2 = "bar"; + + p.apply("impulse", Impulse.create()) + .apply( + "create", + ParDo.of( + new DoFn>() { + @ProcessElement + public void process(ProcessContext ctxt) {} + })) + .setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of())) + .apply( + "userState", + ParDo.of( + new DoFn, KV>() { + + @StateId(stateId) + private final StateSpec> bufferState = + StateSpecs.bag(StringUtf8Coder.of()); + + @StateId(stateId2) + private final StateSpec> bufferState2 = + StateSpecs.bag(StringUtf8Coder.of()); + + @ProcessElement + public void processElement( + @Element KV element, + @StateId(stateId) BagState state, + @StateId(stateId2) BagState state2, + OutputReceiver> r) { + for (String value : state.read()) { + r.output(KV.of(element.getKey(), value)); + } + ReadableState isEmpty = state2.isEmpty(); + if (isEmpty.read()) { + r.output(KV.of(element.getKey(), "Empty")); + } else { + state2.clear(); + } + } + })) + // Force the output to be materialized + .apply("gbk", GroupByKey.create()); + + RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p); + FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineProto); + Optional optionalStage = + Iterables.tryFind( + fused.getFusedStages(), (ExecutableStage stage) -> !stage.getUserStates().isEmpty()); + checkState(optionalStage.isPresent(), "Expected a stage with user state."); + ExecutableStage stage = optionalStage.get(); + + ExecutableProcessBundleDescriptor descriptor = + ProcessBundleDescriptors.fromExecutableStage( + "test_stage", + stage, + dataServer.getApiServiceDescriptor(), + stateServer.getApiServiceDescriptor()); + + BundleProcessor processor = + controlClient.getProcessor( + descriptor.getProcessBundleDescriptor(), + descriptor.getRemoteInputDestinations(), + stateDelegator); + Map remoteOutputCoders = descriptor.getRemoteOutputCoders(); + Map>> outputValues = new HashMap<>(); + Map> outputReceivers = new HashMap<>(); + for (Entry remoteOutputCoder : remoteOutputCoders.entrySet()) { + List> outputContents = Collections.synchronizedList(new ArrayList<>()); + outputValues.put(remoteOutputCoder.getKey(), outputContents); + outputReceivers.put( + remoteOutputCoder.getKey(), + RemoteOutputReceiver.of( + (Coder>) remoteOutputCoder.getValue(), outputContents::add)); + } + + Map> userStateData = + ImmutableMap.of( + stateId, + new ArrayList( + Arrays.asList( + ByteString.copyFrom( + CoderUtils.encodeToByteArray( + StringUtf8Coder.of(), "A", Coder.Context.NESTED)), + ByteString.copyFrom( + CoderUtils.encodeToByteArray( + StringUtf8Coder.of(), "B", Coder.Context.NESTED)), + ByteString.copyFrom( + CoderUtils.encodeToByteArray( + StringUtf8Coder.of(), "C", Coder.Context.NESTED)))), + stateId2, + new ArrayList( + Arrays.asList( + ByteString.copyFrom( + CoderUtils.encodeToByteArray( + StringUtf8Coder.of(), "D", Coder.Context.NESTED))))); + + StoringStateRequestHandler stateRequestHandler = + new StoringStateRequestHandler( + StateRequestHandlers.forBagUserStateHandlerFactory( + descriptor, + new BagUserStateHandlerFactory() { + @Override + public BagUserStateHandler forUserState( + String pTransformId, + String userStateId, + Coder keyCoder, + Coder valueCoder, + Coder windowCoder) { + return new BagUserStateHandler() { + @Override + public Iterable get(ByteString key, BoundedWindow window) { + return (Iterable) userStateData.get(userStateId); + } + + @Override + public void append( + ByteString key, BoundedWindow window, Iterator values) { + Iterators.addAll(userStateData.get(userStateId), (Iterator) values); + } + + @Override + public void clear(ByteString key, BoundedWindow window) { + userStateData.get(userStateId).clear(); + } + }; + } + })); + + try (RemoteBundle bundle = + processor.newBundle( + outputReceivers, stateRequestHandler, BundleProgressHandler.ignored())) { + Iterables.getOnlyElement(bundle.getInputReceivers().values()) + .accept(valueInGlobalWindow(KV.of("X", "Y"))); + } + try (RemoteBundle bundle2 = + processor.newBundle( + outputReceivers, stateRequestHandler, BundleProgressHandler.ignored())) { + Iterables.getOnlyElement(bundle2.getInputReceivers().values()) + .accept(valueInGlobalWindow(KV.of("X", "Z"))); + } + for (Collection> windowedValues : outputValues.values()) { + assertThat( + windowedValues, + containsInAnyOrder( + valueInGlobalWindow(KV.of("X", "A")), + valueInGlobalWindow(KV.of("X", "B")), + valueInGlobalWindow(KV.of("X", "C")), + valueInGlobalWindow(KV.of("X", "A")), + valueInGlobalWindow(KV.of("X", "B")), + valueInGlobalWindow(KV.of("X", "C")), + valueInGlobalWindow(KV.of("X", "Empty")))); + } + assertThat( + userStateData.get(stateId), + IsIterableContainingInOrder.contains( + ByteString.copyFrom( + CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "A", Coder.Context.NESTED)), + ByteString.copyFrom( + CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "B", Coder.Context.NESTED)), + ByteString.copyFrom( + CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "C", Coder.Context.NESTED)))); + assertThat(userStateData.get(stateId2), IsEmptyIterable.emptyIterable()); + + // 3 Requests expected: state read, state2 read, and state2 clear + assertEquals(3, stateRequestHandler.getRequestCount()); + ByteString.Output out = ByteString.newOutput(); + StringUtf8Coder.of().encode("X", out); + + assertEquals( + stateId, + stateRequestHandler + .receivedRequests + .get(0) + .getStateKey() + .getBagUserState() + .getUserStateId()); + assertEquals( + stateRequestHandler.receivedRequests.get(0).getStateKey().getBagUserState().getKey(), + out.toByteString()); + assertTrue(stateRequestHandler.receivedRequests.get(0).hasGet()); + + assertEquals( + stateId2, + stateRequestHandler + .receivedRequests + .get(1) + .getStateKey() + .getBagUserState() + .getUserStateId()); + assertEquals( + stateRequestHandler.receivedRequests.get(1).getStateKey().getBagUserState().getKey(), + out.toByteString()); + assertTrue(stateRequestHandler.receivedRequests.get(1).hasGet()); + + assertEquals( + stateId2, + stateRequestHandler + .receivedRequests + .get(2) + .getStateKey() + .getBagUserState() + .getUserStateId()); + assertEquals( + stateRequestHandler.receivedRequests.get(2).getStateKey().getBagUserState().getKey(), + out.toByteString()); + assertTrue(stateRequestHandler.receivedRequests.get(2).hasClear()); + } + + /** + * A state handler that stores each state request made - used to validate that cached requests are + * not forwarded to the state client. + */ + private static class StoringStateRequestHandler implements StateRequestHandler { + + private StateRequestHandler stateRequestHandler; + private ArrayList receivedRequests; + private ArrayList cacheTokens; + + StoringStateRequestHandler(StateRequestHandler delegate) { + stateRequestHandler = delegate; + receivedRequests = new ArrayList<>(); + cacheTokens = new ArrayList<>(); + } + + @Override + public CompletionStage handle(BeamFnApi.StateRequest request) + throws Exception { + receivedRequests.add(request); + return stateRequestHandler.handle(request); + } + + @Override + public Iterable getCacheTokens() { + return Iterables.concat(stateRequestHandler.getCacheTokens(), cacheTokens); + } + + public int getRequestCount() { + return receivedRequests.size(); + } + + public void addCacheToken(BeamFnApi.ProcessBundleRequest.CacheToken token) { + cacheTokens.add(token); + } + } + + @Test public void testExecutionWithTimer() throws Exception { Pipeline p = Pipeline.create(); diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java index b40c9d5..e1d6ad4 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java @@ -24,6 +24,7 @@ import java.time.Duration; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -50,6 +51,7 @@ import org.apache.beam.fn.harness.data.QueueingBeamFnDataClient; import org.apache.beam.fn.harness.logging.BeamFnLoggingMDC; import org.apache.beam.fn.harness.state.BeamFnStateClient; import org.apache.beam.fn.harness.state.BeamFnStateGrpcClientCache; +import org.apache.beam.fn.harness.state.CachingBeamFnStateClient; import org.apache.beam.model.fnexecution.v1.BeamFnApi; import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleDescriptor; import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleRequest; @@ -74,6 +76,7 @@ import org.apache.beam.sdk.fn.data.FnDataReceiver; import org.apache.beam.sdk.fn.data.LogicalEndpoint; import org.apache.beam.sdk.function.ThrowingRunnable; import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.StreamingOptions; import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; import org.apache.beam.sdk.util.common.ReflectHelpers; import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString; @@ -140,10 +143,29 @@ public class ProcessBundleHandler { REGISTERED_RUNNER_FACTORIES = builder.build(); } + // Creates a new map of state data for newly encountered state keys + private CacheLoader< + BeamFnApi.StateKey, + Map> + stateKeyMapCacheLoader = + new CacheLoader< + BeamFnApi.StateKey, + Map>() { + @Override + public Map load( + BeamFnApi.StateKey key) { + return new HashMap<>(); + } + }; + private final PipelineOptions options; private final Function fnApiRegistry; private final BeamFnDataClient beamFnDataClient; private final BeamFnStateGrpcClientCache beamFnStateGrpcClientCache; + private final LoadingCache< + BeamFnApi.StateKey, + Map> + stateCache; private final FinalizeBundleHandler finalizeBundleHandler; private final ShortIdMap shortIds; private final boolean runnerAcceptsShortIds; @@ -186,6 +208,7 @@ public class ProcessBundleHandler { this.fnApiRegistry = fnApiRegistry; this.beamFnDataClient = beamFnDataClient; this.beamFnStateGrpcClientCache = beamFnStateGrpcClientCache; + this.stateCache = CacheBuilder.newBuilder().build(stateKeyMapCacheLoader); this.finalizeBundleHandler = finalizeBundleHandler; this.shortIds = shortIds; this.runnerAcceptsShortIds = @@ -491,14 +514,27 @@ public class ProcessBundleHandler { } } - // Instantiate a State API call handler depending on whether a State ApiServiceDescriptor - // was specified. - HandleStateCallsForBundle beamFnStateClient = - bundleDescriptor.hasStateApiServiceDescriptor() - ? new BlockTillStateCallsFinish( - beamFnStateGrpcClientCache.forApiServiceDescriptor( - bundleDescriptor.getStateApiServiceDescriptor())) - : new FailAllStateCallsForBundle(processBundleRequest); + // Instantiate a State API call handler depending on whether a State ApiServiceDescriptor was + // specified. + HandleStateCallsForBundle beamFnStateClient; + if (bundleDescriptor.hasStateApiServiceDescriptor()) { + BeamFnStateClient underlyingClient = + beamFnStateGrpcClientCache.forApiServiceDescriptor( + bundleDescriptor.getStateApiServiceDescriptor()); + + // If pipeline is batch, use a CachingBeamFnStateClient to store state responses. + // Once streaming is supported, always use CachingBeamFnStateClient as the arg + // to BlockTillStateCallsFinish + beamFnStateClient = + new BlockTillStateCallsFinish( + options.as(StreamingOptions.class).isStreaming() + ? underlyingClient + : new CachingBeamFnStateClient( + underlyingClient, stateCache, processBundleRequest.getCacheTokensList())); + + } else { + beamFnStateClient = new FailAllStateCallsForBundle(processBundleRequest); + } // Instantiate a Timer client registration handler depending on whether a Timer // ApiServiceDescriptor was specified.