beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From lc...@apache.org
Subject [beam] branch master updated: Add QueueingBeamFnDataClient and make process, finish and start run on the same thread to support metrics. (#6786)
Date Fri, 07 Dec 2018 01:37:44 GMT
This is an automated email from the ASF dual-hosted git repository.

lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 4cd1226  Add QueueingBeamFnDataClient and make process, finish and start run on the
same thread to support metrics. (#6786)
4cd1226 is described below

commit 4cd12262dc22b765a48d2f6ecc1d3ca1ce43a1c9
Author: Alex Amato <ajamato@google.com>
AuthorDate: Thu Dec 6 17:37:36 2018 -0800

    Add QueueingBeamFnDataClient and make process, finish and start run on the same thread
to support metrics. (#6786)
---
 .../fnexecution/control/RemoteExecutionTest.java   |  86 +++++
 .../fn/harness/control/ProcessBundleHandler.java   |  18 +-
 .../fn/harness/data/QueueingBeamFnDataClient.java  | 182 +++++++++++
 .../harness/data/QueueingBeamFnDataClientTest.java | 361 +++++++++++++++++++++
 4 files changed, 646 insertions(+), 1 deletion(-)

diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
index 748efb4..f53257f 100644
--- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
+++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
@@ -21,6 +21,7 @@ import static com.google.common.base.Preconditions.checkState;
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.equalTo;
 import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
 import com.google.common.base.Optional;
@@ -279,6 +280,91 @@ public class RemoteExecutionTest implements Serializable {
   }
 
   @Test
+  public void testBundleProcessorThrowsExecutionExceptionWhenUserCodeThrows() throws Exception
{
+    Pipeline p = Pipeline.create();
+    p.apply("impulse", Impulse.create())
+        .apply(
+            "create",
+            ParDo.of(
+                new DoFn<byte[], KV<String, String>>() {
+                  @ProcessElement
+                  public void process(ProcessContext ctxt) throws Exception {
+                    String element =
+                        CoderUtils.decodeFromByteArray(StringUtf8Coder.of(), ctxt.element());
+                    if (element.equals("X")) {
+                      throw new Exception("testBundleExecutionFailure");
+                    }
+                    ctxt.output(KV.of(element, element));
+                  }
+                }))
+        .apply("gbk", GroupByKey.create());
+
+    RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
+    FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineProto);
+    checkState(fused.getFusedStages().size() == 1, "Expected exactly one fused stage");
+    ExecutableStage stage = fused.getFusedStages().iterator().next();
+
+    ExecutableProcessBundleDescriptor descriptor =
+        ProcessBundleDescriptors.fromExecutableStage(
+            "my_stage", stage, dataServer.getApiServiceDescriptor());
+
+    BundleProcessor processor =
+        controlClient.getProcessor(
+            descriptor.getProcessBundleDescriptor(), descriptor.getRemoteInputDestinations());
+    Map<Target, ? super Coder<WindowedValue<?>>> outputTargets = descriptor.getOutputTargetCoders();
+    Map<Target, Collection<? super WindowedValue<?>>> outputValues = new
HashMap<>();
+    Map<Target, RemoteOutputReceiver<?>> outputReceivers = new HashMap<>();
+    for (Entry<Target, ? super Coder<WindowedValue<?>>> targetCoder : outputTargets.entrySet())
{
+      List<? super WindowedValue<?>> outputContents =
+          Collections.synchronizedList(new ArrayList<>());
+      outputValues.put(targetCoder.getKey(), outputContents);
+      outputReceivers.put(
+          targetCoder.getKey(),
+          RemoteOutputReceiver.of(
+              (Coder) targetCoder.getValue(),
+              (FnDataReceiver<? super WindowedValue<?>>) outputContents::add));
+    }
+
+    try (ActiveBundle bundle =
+        processor.newBundle(outputReceivers, BundleProgressHandler.ignored())) {
+      Iterables.getOnlyElement(bundle.getInputReceivers().values())
+          .accept(
+              WindowedValue.valueInGlobalWindow(
+                  CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "Y")));
+    }
+
+    try {
+      try (ActiveBundle bundle =
+          processor.newBundle(outputReceivers, BundleProgressHandler.ignored())) {
+        Iterables.getOnlyElement(bundle.getInputReceivers().values())
+            .accept(
+                WindowedValue.valueInGlobalWindow(
+                    CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "X")));
+      }
+      // Fail the test if we reach this point and never threw the exception.
+      fail();
+    } catch (ExecutionException e) {
+      assertTrue(e.getMessage().contains("testBundleExecutionFailure"));
+    }
+
+    try (ActiveBundle bundle =
+        processor.newBundle(outputReceivers, BundleProgressHandler.ignored())) {
+      Iterables.getOnlyElement(bundle.getInputReceivers().values())
+          .accept(
+              WindowedValue.valueInGlobalWindow(
+                  CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "Z")));
+    }
+
+    for (Collection<? super WindowedValue<?>> windowedValues : outputValues.values())
{
+      assertThat(
+          windowedValues,
+          containsInAnyOrder(
+              WindowedValue.valueInGlobalWindow(kvBytes("Y", "Y")),
+              WindowedValue.valueInGlobalWindow(kvBytes("Z", "Z"))));
+    }
+  }
+
+  @Test
   public void testExecutionWithSideInput() throws Exception {
     Pipeline p = Pipeline.create();
     PCollection<String> input =
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
index 3c9d183..e22a6c2 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
@@ -41,6 +41,7 @@ import java.util.function.Supplier;
 import org.apache.beam.fn.harness.PTransformRunnerFactory;
 import org.apache.beam.fn.harness.PTransformRunnerFactory.Registrar;
 import org.apache.beam.fn.harness.data.BeamFnDataClient;
+import org.apache.beam.fn.harness.data.QueueingBeamFnDataClient;
 import org.apache.beam.fn.harness.state.BeamFnStateClient;
 import org.apache.beam.fn.harness.state.BeamFnStateGrpcClientCache;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi;
@@ -138,6 +139,7 @@ public class ProcessBundleHandler {
 
   private void createRunnerAndConsumersForPTransformRecursively(
       BeamFnStateClient beamFnStateClient,
+      BeamFnDataClient queueingClient,
       String pTransformId,
       PTransform pTransform,
       Supplier<String> processBundleInstructionId,
@@ -158,6 +160,7 @@ public class ProcessBundleHandler {
       for (String consumingPTransformId : pCollectionIdsToConsumingPTransforms.get(pCollectionId))
{
         createRunnerAndConsumersForPTransformRecursively(
             beamFnStateClient,
+            queueingClient,
             consumingPTransformId,
             processBundleDescriptor.getTransformsMap().get(consumingPTransformId),
             processBundleInstructionId,
@@ -188,7 +191,7 @@ public class ProcessBundleHandler {
           .getOrDefault(pTransform.getSpec().getUrn(), defaultPTransformRunnerFactory)
           .createRunnerForPTransform(
               options,
-              beamFnDataClient,
+              queueingClient,
               beamFnStateClient,
               pTransformId,
               pTransform,
@@ -204,8 +207,17 @@ public class ProcessBundleHandler {
     }
   }
 
+  /**
+   * Processes a bundle, running the start(), process(), and finish() functions. This function
is
+   * required to be reentrant.
+   */
   public BeamFnApi.InstructionResponse.Builder processBundle(BeamFnApi.InstructionRequest
request)
       throws Exception {
+    // Note: We must create one instance of the QueueingBeamFnDataClient as it is designed
to
+    // handle the life of a bundle. It will insert elements onto a queue and drain them off
so all
+    // process() calls will execute on this thread when queueingClient.drainAndBlock() is
called.
+    QueueingBeamFnDataClient queueingClient = new QueueingBeamFnDataClient(this.beamFnDataClient);
+
     String bundleId = request.getProcessBundle().getProcessBundleDescriptorReference();
     BeamFnApi.ProcessBundleDescriptor bundleDescriptor =
         (BeamFnApi.ProcessBundleDescriptor) fnApiRegistry.apply(bundleId);
@@ -255,6 +267,7 @@ public class ProcessBundleHandler {
       // Create a BeamFnStateClient
       for (Map.Entry<String, RunnerApi.PTransform> entry :
           bundleDescriptor.getTransformsMap().entrySet()) {
+
         // Skip anything which isn't a root
         // TODO: Remove source as a root and have it be triggered by the Runner.
         if (!DATA_INPUT_URN.equals(entry.getValue().getSpec().getUrn())
@@ -266,6 +279,7 @@ public class ProcessBundleHandler {
 
         createRunnerAndConsumersForPTransformRecursively(
             beamFnStateClient,
+            queueingClient,
             entry.getKey(),
             entry.getValue(),
             request::getInstructionId,
@@ -284,6 +298,8 @@ public class ProcessBundleHandler {
         startFunction.run();
       }
 
+      queueingClient.drainAndBlock();
+
       // Need to reverse this since we want to call finish in topological order.
       for (ThrowingRunnable finishFunction : Lists.reverse(finishFunctions)) {
         LOG.debug("Finishing function {}", finishFunction);
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/QueueingBeamFnDataClient.java
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/QueueingBeamFnDataClient.java
new file mode 100644
index 0000000..194672d
--- /dev/null
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/QueueingBeamFnDataClient.java
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.fn.harness.data;
+
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.SynchronousQueue;
+import java.util.concurrent.TimeUnit;
+import org.apache.beam.fn.harness.control.ProcessBundleHandler;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.InstructionRequest;
+import org.apache.beam.model.pipeline.v1.Endpoints;
+import org.apache.beam.model.pipeline.v1.Endpoints.ApiServiceDescriptor;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.fn.data.CloseableFnDataReceiver;
+import org.apache.beam.sdk.fn.data.FnDataReceiver;
+import org.apache.beam.sdk.fn.data.InboundDataClient;
+import org.apache.beam.sdk.fn.data.LogicalEndpoint;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A {@link BeamFnDataClient} that queues elements so that they can be consumed and processed
in the
+ * thread which calls {@link #drainAndBlock}.
+ */
+public class QueueingBeamFnDataClient implements BeamFnDataClient {
+
+  private static final Logger LOG = LoggerFactory.getLogger(QueueingBeamFnDataClient.class);
+
+  private final BeamFnDataClient mainClient;
+  private final SynchronousQueue<ConsumerAndData> queue;
+  private final ConcurrentHashMap<InboundDataClient, Object> inboundDataClients;
+
+  public QueueingBeamFnDataClient(BeamFnDataClient mainClient) {
+    this.mainClient = mainClient;
+    this.queue = new SynchronousQueue<>();
+    this.inboundDataClients = new ConcurrentHashMap<>();
+  }
+
+  @Override
+  public <T> InboundDataClient receive(
+      ApiServiceDescriptor apiServiceDescriptor,
+      LogicalEndpoint inputLocation,
+      Coder<WindowedValue<T>> coder,
+      FnDataReceiver<WindowedValue<T>> consumer) {
+    LOG.debug(
+        "Registering consumer for instruction {} and target {}",
+        inputLocation.getInstructionId(),
+        inputLocation.getTarget());
+
+    QueueingFnDataReceiver<T> queueingConsumer = new QueueingFnDataReceiver<T>(consumer);
+    InboundDataClient inboundDataClient =
+        this.mainClient.receive(apiServiceDescriptor, inputLocation, coder, queueingConsumer);
+    queueingConsumer.inboundDataClient = inboundDataClient;
+    this.inboundDataClients.computeIfAbsent(
+        inboundDataClient, (InboundDataClient idcToStore) -> idcToStore);
+    return inboundDataClient;
+  }
+
+  // Returns true if all the InboundDataClients have finished or been cancelled.
+  private boolean allDone() {
+    for (InboundDataClient inboundDataClient : inboundDataClients.keySet()) {
+      if (!inboundDataClient.isDone()) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  /**
+   * Drains the internal queue of this class, by waiting for all WindowedValues to be passed
to
+   * their consumers. The thread which wishes to process() the elements should call this
method, as
+   * this will cause the consumers to invoke element processing. All receive() and send()
calls must
+   * be made prior to calling drainAndBlock, in order to properly terminate.
+   *
+   * <p>All {@link InboundDataClient}s will be failed if processing throws an exception.
+   *
+   * <p>This method is NOT thread safe. This should only be invoked by a single thread,
and is
+   * intended for use with a newly constructed QueueingBeamFnDataClient in {@link
+   * ProcessBundleHandler#processBundle(InstructionRequest)}.
+   */
+  public void drainAndBlock() throws Exception {
+    while (true) {
+      try {
+        ConsumerAndData tuple = queue.poll(200, TimeUnit.MILLISECONDS);
+        if (tuple != null) {
+          // Forward to the consumer that cares about this data.
+          tuple.consumer.accept(tuple.data);
+        } else {
+          // Note: We do not expect to ever hit this point without receiving all values
+          // as (1) The InboundObserver will not be set to Done until the
+          // QueueingFnDataReceiver.accept() call returns and will not be invoked again.
+          // (2) The QueueingFnDataReceiver will not return until the value is received in
+          // drainAndBlock, because of the use of the SynchronousQueue.
+          if (allDone()) {
+            break;
+          }
+        }
+      } catch (Exception e) {
+        LOG.error("Client failed to dequeue and process WindowedValue", e);
+        for (InboundDataClient inboundDataClient : inboundDataClients.keySet()) {
+          inboundDataClient.fail(e);
+        }
+        throw e;
+      }
+    }
+  }
+
+  @Override
+  public <T> CloseableFnDataReceiver<WindowedValue<T>> send(
+      Endpoints.ApiServiceDescriptor apiServiceDescriptor,
+      LogicalEndpoint outputLocation,
+      Coder<WindowedValue<T>> coder) {
+    LOG.debug(
+        "Creating output consumer for instruction {} and target {}",
+        outputLocation.getInstructionId(),
+        outputLocation.getTarget());
+    return this.mainClient.send(apiServiceDescriptor, outputLocation, coder);
+  }
+
+  /**
+   * The QueueingFnDataReceiver is a FnDataReceiver used by the QueueingBeamFnDataClient.
+   *
+   * <p>All {@link #accept accept()ed} values will be put onto a synchronous queue
which will cause
+   * the calling thread to block until {@link QueueingBeamFnDataClient#drainAndBlock} is
called.
+   * {@link QueueingBeamFnDataClient#drainAndBlock} is responsible for processing values
from the
+   * queue.
+   */
+  public class QueueingFnDataReceiver<T> implements FnDataReceiver<WindowedValue<T>>
{
+    private final FnDataReceiver<WindowedValue<T>> consumer;
+    public InboundDataClient inboundDataClient;
+
+    public QueueingFnDataReceiver(FnDataReceiver<WindowedValue<T>> consumer)
{
+      this.consumer = consumer;
+    }
+
+    /**
+     * This method is thread safe; we expect multiple threads to call this, passing in data
when new
+     * data arrives via the QueueingBeamFnDataClient's mainClient.
+     */
+    @Override
+    public void accept(WindowedValue<T> value) throws Exception {
+      try {
+        ConsumerAndData offering = new ConsumerAndData(this.consumer, value);
+        while (!queue.offer(offering, 200, TimeUnit.MILLISECONDS)) {
+          if (inboundDataClient.isDone()) {
+            // If it was cancelled by the consuming side of the queue.
+            break;
+          }
+        }
+      } catch (Exception e) {
+        LOG.error("Failed to insert WindowedValue into the queue", e);
+        inboundDataClient.fail(e);
+        throw e;
+      }
+    }
+  }
+
+  static class ConsumerAndData<T> {
+    public FnDataReceiver<WindowedValue<T>> consumer;
+    public WindowedValue<T> data;
+
+    public ConsumerAndData(FnDataReceiver<WindowedValue<T>> receiver, WindowedValue<T>
data) {
+      this.consumer = receiver;
+      this.data = data;
+    }
+  }
+}
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/QueueingBeamFnDataClientTest.java
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/QueueingBeamFnDataClientTest.java
new file mode 100644
index 0000000..3bb77f7
--- /dev/null
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/QueueingBeamFnDataClientTest.java
@@ -0,0 +1,361 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.fn.harness.data;
+
+import static org.apache.beam.sdk.util.CoderUtils.encodeToByteArray;
+import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow;
+import static org.hamcrest.Matchers.contains;
+import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+import java.util.Collection;
+import java.util.UUID;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.atomic.AtomicReference;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.Target;
+import org.apache.beam.model.fnexecution.v1.BeamFnDataGrpc;
+import org.apache.beam.model.pipeline.v1.Endpoints;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.LengthPrefixCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.fn.data.InboundDataClient;
+import org.apache.beam.sdk.fn.data.LogicalEndpoint;
+import org.apache.beam.sdk.fn.stream.OutboundObserverFactory;
+import org.apache.beam.sdk.fn.test.TestExecutors;
+import org.apache.beam.sdk.fn.test.TestExecutors.TestExecutorService;
+import org.apache.beam.sdk.fn.test.TestStreams;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.vendor.grpc.v1_13_1.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.grpc.v1_13_1.io.grpc.ManagedChannel;
+import org.apache.beam.vendor.grpc.v1_13_1.io.grpc.Server;
+import org.apache.beam.vendor.grpc.v1_13_1.io.grpc.inprocess.InProcessChannelBuilder;
+import org.apache.beam.vendor.grpc.v1_13_1.io.grpc.inprocess.InProcessServerBuilder;
+import org.apache.beam.vendor.grpc.v1_13_1.io.grpc.stub.CallStreamObserver;
+import org.apache.beam.vendor.grpc.v1_13_1.io.grpc.stub.StreamObserver;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** Tests for {@link QueueingBeamFnDataClient}. */
+@RunWith(JUnit4.class)
+public class QueueingBeamFnDataClientTest {
+
+  private static final Logger LOG = LoggerFactory.getLogger(QueueingBeamFnDataClientTest.class);
+
+  @Rule public TestExecutorService executor = TestExecutors.from(Executors::newCachedThreadPool);
+
+  private static final Coder<WindowedValue<String>> CODER =
+      LengthPrefixCoder.of(
+          WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE));
+  private static final LogicalEndpoint ENDPOINT_A =
+      LogicalEndpoint.of(
+          "12L",
+          Target.newBuilder().setPrimitiveTransformReference("34L").setName("targetA").build());
+
+  private static final LogicalEndpoint ENDPOINT_B =
+      LogicalEndpoint.of(
+          "56L",
+          BeamFnApi.Target.newBuilder()
+              .setPrimitiveTransformReference("78L")
+              .setName("targetB")
+              .build());
+
+  private static final BeamFnApi.Elements ELEMENTS_A_1;
+  private static final BeamFnApi.Elements ELEMENTS_A_2;
+  private static final BeamFnApi.Elements ELEMENTS_B_1;
+
+  static {
+    try {
+      ELEMENTS_A_1 =
+          BeamFnApi.Elements.newBuilder()
+              .addData(
+                  BeamFnApi.Elements.Data.newBuilder()
+                      .setInstructionReference(ENDPOINT_A.getInstructionId())
+                      .setTarget(ENDPOINT_A.getTarget())
+                      .setData(
+                          ByteString.copyFrom(encodeToByteArray(CODER, valueInGlobalWindow("ABC")))
+                              .concat(
+                                  ByteString.copyFrom(
+                                      encodeToByteArray(CODER, valueInGlobalWindow("DEF"))))))
+              .build();
+      ELEMENTS_A_2 =
+          BeamFnApi.Elements.newBuilder()
+              .addData(
+                  BeamFnApi.Elements.Data.newBuilder()
+                      .setInstructionReference(ENDPOINT_A.getInstructionId())
+                      .setTarget(ENDPOINT_A.getTarget())
+                      .setData(
+                          ByteString.copyFrom(
+                              encodeToByteArray(CODER, valueInGlobalWindow("GHI")))))
+              .addData(
+                  BeamFnApi.Elements.Data.newBuilder()
+                      .setInstructionReference(ENDPOINT_A.getInstructionId())
+                      .setTarget(ENDPOINT_A.getTarget()))
+              .build();
+      ELEMENTS_B_1 =
+          BeamFnApi.Elements.newBuilder()
+              .addData(
+                  BeamFnApi.Elements.Data.newBuilder()
+                      .setInstructionReference(ENDPOINT_B.getInstructionId())
+                      .setTarget(ENDPOINT_B.getTarget())
+                      .setData(
+                          ByteString.copyFrom(encodeToByteArray(CODER, valueInGlobalWindow("JKL")))
+                              .concat(
+                                  ByteString.copyFrom(
+                                      encodeToByteArray(CODER, valueInGlobalWindow("MNO"))))))
+              .addData(
+                  BeamFnApi.Elements.Data.newBuilder()
+                      .setInstructionReference(ENDPOINT_B.getInstructionId())
+                      .setTarget(ENDPOINT_B.getTarget()))
+              .build();
+    } catch (Exception e) {
+      throw new ExceptionInInitializerError(e);
+    }
+  }
+
+  @Test(timeout = 10000)
+  public void testBasicInboundConsumerBehaviour() throws Exception {
+    CountDownLatch waitForClientToConnect = new CountDownLatch(1);
+    CountDownLatch receiveAllValuesA = new CountDownLatch(3);
+    CountDownLatch receiveAllValuesB = new CountDownLatch(2);
+    Collection<WindowedValue<String>> inboundValuesA = new ConcurrentLinkedQueue<>();
+    Collection<WindowedValue<String>> inboundValuesB = new ConcurrentLinkedQueue<>();
+    Collection<BeamFnApi.Elements> inboundServerValues = new ConcurrentLinkedQueue<>();
+    AtomicReference<StreamObserver<BeamFnApi.Elements>> outboundServerObserver
=
+        new AtomicReference<>();
+    CallStreamObserver<BeamFnApi.Elements> inboundServerObserver =
+        TestStreams.withOnNext(inboundServerValues::add).build();
+
+    Endpoints.ApiServiceDescriptor apiServiceDescriptor =
+        Endpoints.ApiServiceDescriptor.newBuilder()
+            .setUrl(this.getClass().getName() + "-" + UUID.randomUUID().toString())
+            .build();
+    Server server =
+        InProcessServerBuilder.forName(apiServiceDescriptor.getUrl())
+            .addService(
+                new BeamFnDataGrpc.BeamFnDataImplBase() {
+                  @Override
+                  public StreamObserver<BeamFnApi.Elements> data(
+                      StreamObserver<BeamFnApi.Elements> outboundObserver) {
+                    outboundServerObserver.set(outboundObserver);
+                    waitForClientToConnect.countDown();
+                    return inboundServerObserver;
+                  }
+                })
+            .build();
+    server.start();
+    try {
+      ManagedChannel channel =
+          InProcessChannelBuilder.forName(apiServiceDescriptor.getUrl()).build();
+
+      BeamFnDataGrpcClient clientFactory =
+          new BeamFnDataGrpcClient(
+              PipelineOptionsFactory.create(),
+              (Endpoints.ApiServiceDescriptor descriptor) -> channel,
+              OutboundObserverFactory.trivial());
+      QueueingBeamFnDataClient queueingClient = new QueueingBeamFnDataClient(clientFactory);
+
+      InboundDataClient readFutureA =
+          queueingClient.receive(
+              apiServiceDescriptor,
+              ENDPOINT_A,
+              CODER,
+              (WindowedValue<String> wv) -> {
+                inboundValuesA.add(wv);
+                receiveAllValuesA.countDown();
+              });
+
+      waitForClientToConnect.await();
+
+      Future<?> sendElementsFuture =
+          executor.submit(
+              () -> {
+                outboundServerObserver.get().onNext(ELEMENTS_A_1);
+                // Purposefully transmit some data before the consumer for B is bound showing
that
+                // data is not lost
+                outboundServerObserver.get().onNext(ELEMENTS_B_1);
+              });
+
+      // Can this be completed before we get values?
+      InboundDataClient readFutureB =
+          queueingClient.receive(
+              apiServiceDescriptor,
+              ENDPOINT_B,
+              CODER,
+              (WindowedValue<String> wv) -> {
+                inboundValuesB.add(wv);
+                receiveAllValuesB.countDown();
+              });
+
+      Future<?> drainElementsFuture =
+          executor.submit(
+              () -> {
+                try {
+                  queueingClient.drainAndBlock();
+                } catch (Exception e) {
+                  LOG.error("Failed ", e);
+                  fail();
+                }
+              });
+
+      receiveAllValuesB.await();
+      assertThat(inboundValuesB, contains(valueInGlobalWindow("JKL"), valueInGlobalWindow("MNO")));
+
+      outboundServerObserver.get().onNext(ELEMENTS_A_2);
+
+      receiveAllValuesA.await(); // Wait for A's values to be available
+      assertThat(
+          inboundValuesA,
+          contains(
+              valueInGlobalWindow("ABC"), valueInGlobalWindow("DEF"), valueInGlobalWindow("GHI")));
+
+      // Wait for these threads to terminate
+      sendElementsFuture.get();
+      drainElementsFuture.get();
+    } finally {
+      server.shutdownNow();
+    }
+  }
+
+  @Test(timeout = 10000)
+  public void testBundleProcessorThrowsExecutionExceptionWhenUserCodeThrows() throws Exception
{
+    CountDownLatch waitForClientToConnect = new CountDownLatch(1);
+    //Collection<WindowedValue<String>> inboundValuesA = new ConcurrentLinkedQueue<>();
+    Collection<WindowedValue<String>> inboundValuesB = new ConcurrentLinkedQueue<>();
+    Collection<BeamFnApi.Elements> inboundServerValues = new ConcurrentLinkedQueue<>();
+    AtomicReference<StreamObserver<BeamFnApi.Elements>> outboundServerObserver
=
+        new AtomicReference<>();
+    CallStreamObserver<BeamFnApi.Elements> inboundServerObserver =
+        TestStreams.withOnNext(inboundServerValues::add).build();
+
+    Endpoints.ApiServiceDescriptor apiServiceDescriptor =
+        Endpoints.ApiServiceDescriptor.newBuilder()
+            .setUrl(this.getClass().getName() + "-" + UUID.randomUUID().toString())
+            .build();
+    Server server =
+        InProcessServerBuilder.forName(apiServiceDescriptor.getUrl())
+            .addService(
+                new BeamFnDataGrpc.BeamFnDataImplBase() {
+                  @Override
+                  public StreamObserver<BeamFnApi.Elements> data(
+                      StreamObserver<BeamFnApi.Elements> outboundObserver) {
+                    outboundServerObserver.set(outboundObserver);
+                    waitForClientToConnect.countDown();
+                    return inboundServerObserver;
+                  }
+                })
+            .build();
+    server.start();
+    try {
+      ManagedChannel channel =
+          InProcessChannelBuilder.forName(apiServiceDescriptor.getUrl()).build();
+
+      BeamFnDataGrpcClient clientFactory =
+          new BeamFnDataGrpcClient(
+              PipelineOptionsFactory.create(),
+              (Endpoints.ApiServiceDescriptor descriptor) -> channel,
+              OutboundObserverFactory.trivial());
+      QueueingBeamFnDataClient queueingClient = new QueueingBeamFnDataClient(clientFactory);
+
+      InboundDataClient readFutureA =
+          queueingClient.receive(
+              apiServiceDescriptor,
+              ENDPOINT_A,
+              CODER,
+              (WindowedValue<String> wv) -> {
+                throw new RuntimeException("Intentionally fail!"); // Error injected here.
+              });
+
+      waitForClientToConnect.await();
+
+      Future<?> sendElementsFuture =
+          executor.submit(
+              () -> {
+                outboundServerObserver.get().onNext(ELEMENTS_A_1);
+                // Purposefully transmit some data before the consumer for B is bound showing
that
+                // data is not lost
+                outboundServerObserver.get().onNext(ELEMENTS_B_1);
+              });
+
+      InboundDataClient readFutureB =
+          queueingClient.receive(
+              apiServiceDescriptor,
+              ENDPOINT_B,
+              CODER,
+              (WindowedValue<String> wv) -> {
+                inboundValuesB.add(wv);
+              });
+
+      Future<?> drainElementsFuture =
+          executor.submit(
+              () -> {
+                boolean intentionallyFailed = false;
+                try {
+                  queueingClient.drainAndBlock();
+                } catch (RuntimeException e) {
+                  intentionallyFailed = true;
+                } catch (Exception e) {
+                  LOG.error("Unintentional failure", e);
+                  fail();
+                }
+                assertTrue(intentionallyFailed);
+              });
+
+      // Fail all InboundObservers if any of the downstream consumers fail.
+      // This allows the ProcessBundleHandler to unblock everything and fail properly.
+      boolean intentionallyFailedA = false;
+      try {
+        readFutureA.awaitCompletion();
+      } catch (ExecutionException e) {
+        if (e.getCause() instanceof RuntimeException) {
+          intentionallyFailedA = true;
+        }
+      }
+      assertTrue(intentionallyFailedA);
+
+      boolean intentionallyFailedB = false;
+      try {
+        readFutureB.awaitCompletion();
+      } catch (ExecutionException e) {
+        if (e.getCause() instanceof RuntimeException) {
+          intentionallyFailedB = true;
+        }
+      } catch (Exception e) {
+        intentionallyFailedB = true;
+      }
+      assertTrue(intentionallyFailedB);
+
+      // Wait for these threads to terminate
+      sendElementsFuture.get();
+      drainElementsFuture.get();
+    } finally {
+      server.shutdownNow();
+    }
+  }
+}


Mime
View raw message