beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From amaliu...@apache.org
Subject [beam] branch master updated: [BEAM-10212] Integrate caching client (#15214)
Date Fri, 06 Aug 2021 17:04:35 GMT
This is an automated email from the ASF dual-hosted git repository.

amaliujia pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 9c9903d  [BEAM-10212] Integrate caching client (#15214)
9c9903d is described below

commit 9c9903d50b59a6ca956b9d43809dc26c490cb849
Author: anthonyqzhu <43458232+anthonyqzhu@users.noreply.github.com>
AuthorDate: Fri Aug 6 13:03:28 2021 -0400

    [BEAM-10212] Integrate caching client (#15214)
    
    * [BEAM-10212] Add state cache to ProcessBundleHandler
---
 .../fnexecution/control/RemoteExecutionTest.java   | 407 +++++++++++++++++++++
 .../fn/harness/control/ProcessBundleHandler.java   |  52 ++-
 2 files changed, 451 insertions(+), 8 deletions(-)

diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
index 5acb87d..f238601 100644
--- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
+++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
@@ -42,6 +42,7 @@ import java.util.Map;
 import java.util.Map.Entry;
 import java.util.Set;
 import java.util.UUID;
+import java.util.concurrent.CompletionStage;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentMap;
 import java.util.concurrent.CountDownLatch;
@@ -55,6 +56,7 @@ import java.util.concurrent.ThreadFactory;
 import java.util.concurrent.TimeUnit;
 import java.util.function.Function;
 import org.apache.beam.fn.harness.FnHarness;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleProgressResponse;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleSplitResponse;
@@ -68,6 +70,7 @@ import org.apache.beam.runners.core.construction.graph.FusedPipeline;
 import org.apache.beam.runners.core.construction.graph.GreedyPipelineFuser;
 import org.apache.beam.runners.core.construction.graph.PipelineNode.PTransformNode;
 import org.apache.beam.runners.core.construction.graph.ProtoOverrides;
+import org.apache.beam.runners.core.construction.graph.SideInputReference;
 import org.apache.beam.runners.core.construction.graph.SplittableParDoExpander;
 import org.apache.beam.runners.core.metrics.DistributionData;
 import org.apache.beam.runners.core.metrics.ExecutionStateSampler;
@@ -532,6 +535,160 @@ public class RemoteExecutionTest implements Serializable {
     }
   }
 
+  @Test
+  public void testExecutionWithSideInputCaching() throws Exception {
+    Pipeline p = Pipeline.create();
+    addExperiment(p.getOptions().as(ExperimentalOptions.class), "beam_fn_api");
+    // TODO(BEAM-10097): Remove experiment once all portable runners support this view type
+    addExperiment(p.getOptions().as(ExperimentalOptions.class), "use_runner_v2");
+    PCollection<String> input =
+        p.apply("impulse", Impulse.create())
+            .apply(
+                "create",
+                ParDo.of(
+                    new DoFn<byte[], String>() {
+                      @ProcessElement
+                      public void process(ProcessContext ctxt) {
+                        ctxt.output("zero");
+                        ctxt.output("one");
+                        ctxt.output("two");
+                      }
+                    }))
+            .setCoder(StringUtf8Coder.of());
+    PCollectionView<Iterable<String>> view = input.apply("createSideInput", View.asIterable());
+
+    input
+        .apply(
+            "readSideInput",
+            ParDo.of(
+                    new DoFn<String, KV<String, String>>() {
+                      @ProcessElement
+                      public void processElement(ProcessContext context) {
+                        for (String value : context.sideInput(view)) {
+                          context.output(KV.of(context.element(), value));
+                        }
+                      }
+                    })
+                .withSideInputs(view))
+        .setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()))
+        // Force the output to be materialized
+        .apply("gbk", GroupByKey.create());
+
+    RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
+    FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineProto);
+    Optional<ExecutableStage> optionalStage =
+        Iterables.tryFind(
+            fused.getFusedStages(), (ExecutableStage stage) -> !stage.getSideInputs().isEmpty());
+    checkState(optionalStage.isPresent(), "Expected a stage with side inputs.");
+    ExecutableStage stage = optionalStage.get();
+
+    ExecutableProcessBundleDescriptor descriptor =
+        ProcessBundleDescriptors.fromExecutableStage(
+            "test_stage",
+            stage,
+            dataServer.getApiServiceDescriptor(),
+            stateServer.getApiServiceDescriptor());
+
+    BundleProcessor processor =
+        controlClient.getProcessor(
+            descriptor.getProcessBundleDescriptor(),
+            descriptor.getRemoteInputDestinations(),
+            stateDelegator);
+    Map<String, Coder> remoteOutputCoders = descriptor.getRemoteOutputCoders();
+    Map<String, Collection<WindowedValue<?>>> outputValues = new HashMap<>();
+    Map<String, RemoteOutputReceiver<?>> outputReceivers = new HashMap<>();
+    for (Entry<String, Coder> remoteOutputCoder : remoteOutputCoders.entrySet()) {
+      List<WindowedValue<?>> outputContents = Collections.synchronizedList(new ArrayList<>());
+      outputValues.put(remoteOutputCoder.getKey(), outputContents);
+      outputReceivers.put(
+          remoteOutputCoder.getKey(),
+          RemoteOutputReceiver.of(
+              (Coder<WindowedValue<?>>) remoteOutputCoder.getValue(), outputContents::add));
+    }
+
+    StoringStateRequestHandler stateRequestHandler =
+        new StoringStateRequestHandler(
+            StateRequestHandlers.forSideInputHandlerFactory(
+                descriptor.getSideInputSpecs(),
+                new SideInputHandlerFactory() {
+                  @Override
+                  public <V, W extends BoundedWindow>
+                      IterableSideInputHandler<V, W> forIterableSideInput(
+                          String pTransformId,
+                          String sideInputId,
+                          Coder<V> elementCoder,
+                          Coder<W> windowCoder) {
+                    return new IterableSideInputHandler<V, W>() {
+                      @Override
+                      public Iterable<V> get(W window) {
+                        return (Iterable) Arrays.asList("A", "B", "C");
+                      }
+
+                      @Override
+                      public Coder<V> elementCoder() {
+                        return elementCoder;
+                      }
+                    };
+                  }
+
+                  @Override
+                  public <K, V, W extends BoundedWindow>
+                      MultimapSideInputHandler<K, V, W> forMultimapSideInput(
+                          String pTransformId,
+                          String sideInputId,
+                          KvCoder<K, V> elementCoder,
+                          Coder<W> windowCoder) {
+                    throw new UnsupportedOperationException();
+                  }
+                }));
+    SideInputReference sideInputReference = stage.getSideInputs().iterator().next();
+    String transformId = sideInputReference.transform().getId();
+    String sideInputId = sideInputReference.localName();
+    stateRequestHandler.addCacheToken(
+        BeamFnApi.ProcessBundleRequest.CacheToken.newBuilder()
+            .setSideInput(
+                BeamFnApi.ProcessBundleRequest.CacheToken.SideInput.newBuilder()
+                    .setSideInputId(sideInputId)
+                    .setTransformId(transformId)
+                    .build())
+            .setToken(ByteString.copyFromUtf8("SideInputToken"))
+            .build());
+    BundleProgressHandler progressHandler = BundleProgressHandler.ignored();
+
+    try (RemoteBundle bundle =
+        processor.newBundle(outputReceivers, stateRequestHandler, progressHandler)) {
+      Iterables.getOnlyElement(bundle.getInputReceivers().values())
+          .accept(valueInGlobalWindow("X"));
+    }
+
+    try (RemoteBundle bundle =
+        processor.newBundle(outputReceivers, stateRequestHandler, progressHandler)) {
+      Iterables.getOnlyElement(bundle.getInputReceivers().values())
+          .accept(valueInGlobalWindow("X"));
+    }
+    for (Collection<WindowedValue<?>> windowedValues : outputValues.values()) {
+      assertThat(
+          windowedValues,
+          containsInAnyOrder(
+              valueInGlobalWindow(KV.of("X", "A")),
+              valueInGlobalWindow(KV.of("X", "B")),
+              valueInGlobalWindow(KV.of("X", "C")),
+              valueInGlobalWindow(KV.of("X", "A")),
+              valueInGlobalWindow(KV.of("X", "B")),
+              valueInGlobalWindow(KV.of("X", "C"))));
+    }
+
+    // Only expect one read to the sideInput
+    assertEquals(1, stateRequestHandler.receivedRequests.size());
+    BeamFnApi.StateRequest receivedRequest = stateRequestHandler.receivedRequests.get(0);
+    assertEquals(
+        receivedRequest.getStateKey().getIterableSideInput(),
+        BeamFnApi.StateKey.IterableSideInput.newBuilder()
+            .setSideInputId(sideInputId)
+            .setTransformId(transformId)
+            .build());
+  }
+
   /**
    * A {@link DoFn} that uses static maps of {@link CountDownLatch}es to block execution allowing
    * for synchronization during test execution. The expected flow is:
@@ -1041,6 +1198,256 @@ public class RemoteExecutionTest implements Serializable {
   }
 
   @Test
+  public void testExecutionWithUserStateCaching() throws Exception {
+    Pipeline p = Pipeline.create();
+    final String stateId = "foo";
+    final String stateId2 = "bar";
+
+    p.apply("impulse", Impulse.create())
+        .apply(
+            "create",
+            ParDo.of(
+                new DoFn<byte[], KV<String, String>>() {
+                  @ProcessElement
+                  public void process(ProcessContext ctxt) {}
+                }))
+        .setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()))
+        .apply(
+            "userState",
+            ParDo.of(
+                new DoFn<KV<String, String>, KV<String, String>>() {
+
+                  @StateId(stateId)
+                  private final StateSpec<BagState<String>> bufferState =
+                      StateSpecs.bag(StringUtf8Coder.of());
+
+                  @StateId(stateId2)
+                  private final StateSpec<BagState<String>> bufferState2 =
+                      StateSpecs.bag(StringUtf8Coder.of());
+
+                  @ProcessElement
+                  public void processElement(
+                      @Element KV<String, String> element,
+                      @StateId(stateId) BagState<String> state,
+                      @StateId(stateId2) BagState<String> state2,
+                      OutputReceiver<KV<String, String>> r) {
+                    for (String value : state.read()) {
+                      r.output(KV.of(element.getKey(), value));
+                    }
+                    ReadableState<Boolean> isEmpty = state2.isEmpty();
+                    if (isEmpty.read()) {
+                      r.output(KV.of(element.getKey(), "Empty"));
+                    } else {
+                      state2.clear();
+                    }
+                  }
+                }))
+        // Force the output to be materialized
+        .apply("gbk", GroupByKey.create());
+
+    RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
+    FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineProto);
+    Optional<ExecutableStage> optionalStage =
+        Iterables.tryFind(
+            fused.getFusedStages(), (ExecutableStage stage) -> !stage.getUserStates().isEmpty());
+    checkState(optionalStage.isPresent(), "Expected a stage with user state.");
+    ExecutableStage stage = optionalStage.get();
+
+    ExecutableProcessBundleDescriptor descriptor =
+        ProcessBundleDescriptors.fromExecutableStage(
+            "test_stage",
+            stage,
+            dataServer.getApiServiceDescriptor(),
+            stateServer.getApiServiceDescriptor());
+
+    BundleProcessor processor =
+        controlClient.getProcessor(
+            descriptor.getProcessBundleDescriptor(),
+            descriptor.getRemoteInputDestinations(),
+            stateDelegator);
+    Map<String, Coder> remoteOutputCoders = descriptor.getRemoteOutputCoders();
+    Map<String, Collection<WindowedValue<?>>> outputValues = new HashMap<>();
+    Map<String, RemoteOutputReceiver<?>> outputReceivers = new HashMap<>();
+    for (Entry<String, Coder> remoteOutputCoder : remoteOutputCoders.entrySet()) {
+      List<WindowedValue<?>> outputContents = Collections.synchronizedList(new ArrayList<>());
+      outputValues.put(remoteOutputCoder.getKey(), outputContents);
+      outputReceivers.put(
+          remoteOutputCoder.getKey(),
+          RemoteOutputReceiver.of(
+              (Coder<WindowedValue<?>>) remoteOutputCoder.getValue(), outputContents::add));
+    }
+
+    Map<String, List<ByteString>> userStateData =
+        ImmutableMap.of(
+            stateId,
+            new ArrayList(
+                Arrays.asList(
+                    ByteString.copyFrom(
+                        CoderUtils.encodeToByteArray(
+                            StringUtf8Coder.of(), "A", Coder.Context.NESTED)),
+                    ByteString.copyFrom(
+                        CoderUtils.encodeToByteArray(
+                            StringUtf8Coder.of(), "B", Coder.Context.NESTED)),
+                    ByteString.copyFrom(
+                        CoderUtils.encodeToByteArray(
+                            StringUtf8Coder.of(), "C", Coder.Context.NESTED)))),
+            stateId2,
+            new ArrayList(
+                Arrays.asList(
+                    ByteString.copyFrom(
+                        CoderUtils.encodeToByteArray(
+                            StringUtf8Coder.of(), "D", Coder.Context.NESTED)))));
+
+    StoringStateRequestHandler stateRequestHandler =
+        new StoringStateRequestHandler(
+            StateRequestHandlers.forBagUserStateHandlerFactory(
+                descriptor,
+                new BagUserStateHandlerFactory<ByteString, Object, BoundedWindow>() {
+                  @Override
+                  public BagUserStateHandler<ByteString, Object, BoundedWindow> forUserState(
+                      String pTransformId,
+                      String userStateId,
+                      Coder<ByteString> keyCoder,
+                      Coder<Object> valueCoder,
+                      Coder<BoundedWindow> windowCoder) {
+                    return new BagUserStateHandler<ByteString, Object, BoundedWindow>() {
+                      @Override
+                      public Iterable<Object> get(ByteString key, BoundedWindow window) {
+                        return (Iterable) userStateData.get(userStateId);
+                      }
+
+                      @Override
+                      public void append(
+                          ByteString key, BoundedWindow window, Iterator<Object> values) {
+                        Iterators.addAll(userStateData.get(userStateId), (Iterator) values);
+                      }
+
+                      @Override
+                      public void clear(ByteString key, BoundedWindow window) {
+                        userStateData.get(userStateId).clear();
+                      }
+                    };
+                  }
+                }));
+
+    try (RemoteBundle bundle =
+        processor.newBundle(
+            outputReceivers, stateRequestHandler, BundleProgressHandler.ignored())) {
+      Iterables.getOnlyElement(bundle.getInputReceivers().values())
+          .accept(valueInGlobalWindow(KV.of("X", "Y")));
+    }
+    try (RemoteBundle bundle2 =
+        processor.newBundle(
+            outputReceivers, stateRequestHandler, BundleProgressHandler.ignored())) {
+      Iterables.getOnlyElement(bundle2.getInputReceivers().values())
+          .accept(valueInGlobalWindow(KV.of("X", "Z")));
+    }
+    for (Collection<WindowedValue<?>> windowedValues : outputValues.values()) {
+      assertThat(
+          windowedValues,
+          containsInAnyOrder(
+              valueInGlobalWindow(KV.of("X", "A")),
+              valueInGlobalWindow(KV.of("X", "B")),
+              valueInGlobalWindow(KV.of("X", "C")),
+              valueInGlobalWindow(KV.of("X", "A")),
+              valueInGlobalWindow(KV.of("X", "B")),
+              valueInGlobalWindow(KV.of("X", "C")),
+              valueInGlobalWindow(KV.of("X", "Empty"))));
+    }
+    assertThat(
+        userStateData.get(stateId),
+        IsIterableContainingInOrder.contains(
+            ByteString.copyFrom(
+                CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "A", Coder.Context.NESTED)),
+            ByteString.copyFrom(
+                CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "B", Coder.Context.NESTED)),
+            ByteString.copyFrom(
+                CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "C", Coder.Context.NESTED))));
+    assertThat(userStateData.get(stateId2), IsEmptyIterable.emptyIterable());
+
+    // 3 Requests expected: state read, state2 read, and state2 clear
+    assertEquals(3, stateRequestHandler.getRequestCount());
+    ByteString.Output out = ByteString.newOutput();
+    StringUtf8Coder.of().encode("X", out);
+
+    assertEquals(
+        stateId,
+        stateRequestHandler
+            .receivedRequests
+            .get(0)
+            .getStateKey()
+            .getBagUserState()
+            .getUserStateId());
+    assertEquals(
+        stateRequestHandler.receivedRequests.get(0).getStateKey().getBagUserState().getKey(),
+        out.toByteString());
+    assertTrue(stateRequestHandler.receivedRequests.get(0).hasGet());
+
+    assertEquals(
+        stateId2,
+        stateRequestHandler
+            .receivedRequests
+            .get(1)
+            .getStateKey()
+            .getBagUserState()
+            .getUserStateId());
+    assertEquals(
+        stateRequestHandler.receivedRequests.get(1).getStateKey().getBagUserState().getKey(),
+        out.toByteString());
+    assertTrue(stateRequestHandler.receivedRequests.get(1).hasGet());
+
+    assertEquals(
+        stateId2,
+        stateRequestHandler
+            .receivedRequests
+            .get(2)
+            .getStateKey()
+            .getBagUserState()
+            .getUserStateId());
+    assertEquals(
+        stateRequestHandler.receivedRequests.get(2).getStateKey().getBagUserState().getKey(),
+        out.toByteString());
+    assertTrue(stateRequestHandler.receivedRequests.get(2).hasClear());
+  }
+
+  /**
+   * A state handler that stores each state request made - used to validate that cached requests are
+   * not forwarded to the state client.
+   */
+  private static class StoringStateRequestHandler implements StateRequestHandler {
+
+    private StateRequestHandler stateRequestHandler;
+    private ArrayList<BeamFnApi.StateRequest> receivedRequests;
+    private ArrayList<BeamFnApi.ProcessBundleRequest.CacheToken> cacheTokens;
+
+    StoringStateRequestHandler(StateRequestHandler delegate) {
+      stateRequestHandler = delegate;
+      receivedRequests = new ArrayList<>();
+      cacheTokens = new ArrayList<>();
+    }
+
+    @Override
+    public CompletionStage<BeamFnApi.StateResponse.Builder> handle(BeamFnApi.StateRequest request)
+        throws Exception {
+      receivedRequests.add(request);
+      return stateRequestHandler.handle(request);
+    }
+
+    @Override
+    public Iterable<BeamFnApi.ProcessBundleRequest.CacheToken> getCacheTokens() {
+      return Iterables.concat(stateRequestHandler.getCacheTokens(), cacheTokens);
+    }
+
+    public int getRequestCount() {
+      return receivedRequests.size();
+    }
+
+    public void addCacheToken(BeamFnApi.ProcessBundleRequest.CacheToken token) {
+      cacheTokens.add(token);
+    }
+  }
+
+  @Test
   public void testExecutionWithTimer() throws Exception {
     Pipeline p = Pipeline.create();
 
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
index b40c9d5..e1d6ad4 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
@@ -24,6 +24,7 @@ import java.time.Duration;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
@@ -50,6 +51,7 @@ import org.apache.beam.fn.harness.data.QueueingBeamFnDataClient;
 import org.apache.beam.fn.harness.logging.BeamFnLoggingMDC;
 import org.apache.beam.fn.harness.state.BeamFnStateClient;
 import org.apache.beam.fn.harness.state.BeamFnStateGrpcClientCache;
+import org.apache.beam.fn.harness.state.CachingBeamFnStateClient;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleDescriptor;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleRequest;
@@ -74,6 +76,7 @@ import org.apache.beam.sdk.fn.data.FnDataReceiver;
 import org.apache.beam.sdk.fn.data.LogicalEndpoint;
 import org.apache.beam.sdk.function.ThrowingRunnable;
 import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.StreamingOptions;
 import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer;
 import org.apache.beam.sdk.util.common.ReflectHelpers;
 import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
@@ -140,10 +143,29 @@ public class ProcessBundleHandler {
     REGISTERED_RUNNER_FACTORIES = builder.build();
   }
 
+  // Creates a new map of state data for newly encountered state keys
+  private CacheLoader<
+          BeamFnApi.StateKey,
+          Map<CachingBeamFnStateClient.StateCacheKey, BeamFnApi.StateGetResponse>>
+      stateKeyMapCacheLoader =
+          new CacheLoader<
+              BeamFnApi.StateKey,
+              Map<CachingBeamFnStateClient.StateCacheKey, BeamFnApi.StateGetResponse>>() {
+            @Override
+            public Map<CachingBeamFnStateClient.StateCacheKey, BeamFnApi.StateGetResponse> load(
+                BeamFnApi.StateKey key) {
+              return new HashMap<>();
+            }
+          };
+
   private final PipelineOptions options;
   private final Function<String, Message> fnApiRegistry;
   private final BeamFnDataClient beamFnDataClient;
   private final BeamFnStateGrpcClientCache beamFnStateGrpcClientCache;
+  private final LoadingCache<
+          BeamFnApi.StateKey,
+          Map<CachingBeamFnStateClient.StateCacheKey, BeamFnApi.StateGetResponse>>
+      stateCache;
   private final FinalizeBundleHandler finalizeBundleHandler;
   private final ShortIdMap shortIds;
   private final boolean runnerAcceptsShortIds;
@@ -186,6 +208,7 @@ public class ProcessBundleHandler {
     this.fnApiRegistry = fnApiRegistry;
     this.beamFnDataClient = beamFnDataClient;
     this.beamFnStateGrpcClientCache = beamFnStateGrpcClientCache;
+    this.stateCache = CacheBuilder.newBuilder().build(stateKeyMapCacheLoader);
     this.finalizeBundleHandler = finalizeBundleHandler;
     this.shortIds = shortIds;
     this.runnerAcceptsShortIds =
@@ -491,14 +514,27 @@ public class ProcessBundleHandler {
       }
     }
 
-    // Instantiate a State API call handler depending on whether a State ApiServiceDescriptor
-    // was specified.
-    HandleStateCallsForBundle beamFnStateClient =
-        bundleDescriptor.hasStateApiServiceDescriptor()
-            ? new BlockTillStateCallsFinish(
-                beamFnStateGrpcClientCache.forApiServiceDescriptor(
-                    bundleDescriptor.getStateApiServiceDescriptor()))
-            : new FailAllStateCallsForBundle(processBundleRequest);
+    // Instantiate a State API call handler depending on whether a State ApiServiceDescriptor was
+    // specified.
+    HandleStateCallsForBundle beamFnStateClient;
+    if (bundleDescriptor.hasStateApiServiceDescriptor()) {
+      BeamFnStateClient underlyingClient =
+          beamFnStateGrpcClientCache.forApiServiceDescriptor(
+              bundleDescriptor.getStateApiServiceDescriptor());
+
+      // If pipeline is batch, use a CachingBeamFnStateClient to store state responses.
+      // Once streaming is supported, always use CachingBeamFnStateClient as the arg
+      // to BlockTillStateCallsFinish
+      beamFnStateClient =
+          new BlockTillStateCallsFinish(
+              options.as(StreamingOptions.class).isStreaming()
+                  ? underlyingClient
+                  : new CachingBeamFnStateClient(
+                      underlyingClient, stateCache, processBundleRequest.getCacheTokensList()));
+
+    } else {
+      beamFnStateClient = new FailAllStateCallsForBundle(processBundleRequest);
+    }
 
     // Instantiate a Timer client registration handler depending on whether a Timer
     // ApiServiceDescriptor was specified.

Mime
View raw message