beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From lc...@apache.org
Subject [1/3] beam git commit: [BEAM-1347] Break apart ProcessBundleHandler to use service locator pattern based upon URNs.
Date Fri, 16 Jun 2017 19:03:47 GMT
Repository: beam
Updated Branches:
  refs/heads/master 54f307891 -> aa555f593


http://git-wip-us.apache.org/repos/asf/beam/blob/c9c1a05d/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataWriteRunnerTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataWriteRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataWriteRunnerTest.java
index a3c874e..64d9ea6 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataWriteRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BeamFnDataWriteRunnerTest.java
@@ -20,31 +20,48 @@ package org.apache.beam.runners.core;
 
 import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow;
 import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.containsInAnyOrder;
+import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertThat;
 import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoMoreInteractions;
+import static org.mockito.Mockito.verifyZeroInteractions;
 import static org.mockito.Mockito.when;
 
 import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.common.base.Suppliers;
+import com.google.common.collect.HashMultimap;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Multimap;
 import com.google.protobuf.Any;
 import com.google.protobuf.ByteString;
 import com.google.protobuf.BytesValue;
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.List;
+import java.util.ServiceLoader;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicReference;
 import org.apache.beam.fn.harness.data.BeamFnDataClient;
 import org.apache.beam.fn.harness.fn.CloseableThrowingConsumer;
+import org.apache.beam.fn.harness.fn.ThrowingConsumer;
+import org.apache.beam.fn.harness.fn.ThrowingRunnable;
 import org.apache.beam.fn.v1.BeamFnApi;
+import org.apache.beam.runners.core.PTransformRunnerFactory.Registrar;
 import org.apache.beam.runners.dataflow.util.CloudObjects;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.common.runner.v1.RunnerApi;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.KV;
+import org.hamcrest.collection.IsMapContaining;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -56,15 +73,18 @@ import org.mockito.MockitoAnnotations;
 /** Tests for {@link BeamFnDataWriteRunner}. */
 @RunWith(JUnit4.class)
 public class BeamFnDataWriteRunnerTest {
-  private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
 
+  private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
   private static final BeamFnApi.RemoteGrpcPort PORT_SPEC = BeamFnApi.RemoteGrpcPort.newBuilder()
       .setApiServiceDescriptor(BeamFnApi.ApiServiceDescriptor.getDefaultInstance()).build();
   private static final RunnerApi.FunctionSpec FUNCTION_SPEC = RunnerApi.FunctionSpec.newBuilder()
       .setParameter(Any.pack(PORT_SPEC)).build();
+  private static final String CODER_ID = "string-coder-id";
   private static final Coder<WindowedValue<String>> CODER =
       WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE);
   private static final RunnerApi.Coder CODER_SPEC;
+  private static final String URN = "urn:org.apache.beam:sink:runner:0.1";
+
   static {
     try {
       CODER_SPEC = RunnerApi.Coder.newBuilder().setSpec(
@@ -85,18 +105,93 @@ public class BeamFnDataWriteRunnerTest {
       .setName("out")
       .build();
 
-  @Mock private BeamFnDataClient mockBeamFnDataClientFactory;
+  @Mock private BeamFnDataClient mockBeamFnDataClient;
 
   @Before
   public void setUp() {
     MockitoAnnotations.initMocks(this);
   }
 
+
+  @Test
+  public void testCreatingAndProcessingBeamFnDataWriteRunner() throws Exception {
+    String bundleId = "57L";
+    String inputId = "100L";
+
+    Multimap<String, ThrowingConsumer<WindowedValue<?>>> consumers = HashMultimap.create();
+    List<ThrowingRunnable> startFunctions = new ArrayList<>();
+    List<ThrowingRunnable> finishFunctions = new ArrayList<>();
+
+    RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder()
+        .setUrn("urn:org.apache.beam:sink:runner:0.1")
+        .setParameter(Any.pack(PORT_SPEC))
+        .build();
+
+    RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder()
+        .setSpec(functionSpec)
+        .putInputs(inputId, "inputPC")
+        .build();
+
+    new BeamFnDataWriteRunner.Factory<String>().createRunnerForPTransform(
+        PipelineOptionsFactory.create(),
+        mockBeamFnDataClient,
+        "ptransformId",
+        pTransform,
+        Suppliers.ofInstance(bundleId)::get,
+        ImmutableMap.of("inputPC",
+            RunnerApi.PCollection.newBuilder().setCoderId(CODER_ID).build()),
+        ImmutableMap.of(CODER_ID, CODER_SPEC),
+        consumers,
+        startFunctions::add,
+        finishFunctions::add);
+
+    verifyZeroInteractions(mockBeamFnDataClient);
+
+    List<WindowedValue<String>> outputValues = new ArrayList<>();
+    AtomicBoolean wasCloseCalled = new AtomicBoolean();
+    CloseableThrowingConsumer<WindowedValue<String>> outputConsumer =
+        new CloseableThrowingConsumer<WindowedValue<String>>(){
+          @Override
+          public void close() throws Exception {
+            wasCloseCalled.set(true);
+          }
+
+          @Override
+          public void accept(WindowedValue<String> t) throws Exception {
+            outputValues.add(t);
+          }
+        };
+
+    when(mockBeamFnDataClient.forOutboundConsumer(
+        any(),
+        any(),
+        Matchers.<Coder<WindowedValue<String>>>any())).thenReturn(outputConsumer);
+    Iterables.getOnlyElement(startFunctions).run();
+    verify(mockBeamFnDataClient).forOutboundConsumer(
+        eq(PORT_SPEC.getApiServiceDescriptor()),
+        eq(KV.of(bundleId, BeamFnApi.Target.newBuilder()
+            .setPrimitiveTransformReference("ptransformId")
+            .setName(inputId)
+            .build())),
+        eq(CODER));
+
+    assertThat(consumers.keySet(), containsInAnyOrder("inputPC"));
+    Iterables.getOnlyElement(consumers.get("inputPC")).accept(valueInGlobalWindow("TestValue"));
+    assertThat(outputValues, contains(valueInGlobalWindow("TestValue")));
+    outputValues.clear();
+
+    assertFalse(wasCloseCalled.get());
+    Iterables.getOnlyElement(finishFunctions).run();
+    assertTrue(wasCloseCalled.get());
+
+    verifyNoMoreInteractions(mockBeamFnDataClient);
+  }
+
   @Test
   public void testReuseForMultipleBundles() throws Exception {
     RecordingConsumer<WindowedValue<String>> valuesA = new RecordingConsumer<>();
     RecordingConsumer<WindowedValue<String>> valuesB = new RecordingConsumer<>();
-    when(mockBeamFnDataClientFactory.forOutboundConsumer(
+    when(mockBeamFnDataClient.forOutboundConsumer(
         any(),
         any(),
         Matchers.<Coder<WindowedValue<String>>>any())).thenReturn(valuesA).thenReturn(valuesB);
@@ -106,12 +201,12 @@ public class BeamFnDataWriteRunnerTest {
         bundleId::get,
         OUTPUT_TARGET,
         CODER_SPEC,
-        mockBeamFnDataClientFactory);
+        mockBeamFnDataClient);
 
     // Process for bundle id 0
     writeRunner.registerForOutput();
 
-    verify(mockBeamFnDataClientFactory).forOutboundConsumer(
+    verify(mockBeamFnDataClient).forOutboundConsumer(
         eq(PORT_SPEC.getApiServiceDescriptor()),
         eq(KV.of(bundleId.get(), OUTPUT_TARGET)),
         eq(CODER));
@@ -129,7 +224,7 @@ public class BeamFnDataWriteRunnerTest {
     valuesB.clear();
     writeRunner.registerForOutput();
 
-    verify(mockBeamFnDataClientFactory).forOutboundConsumer(
+    verify(mockBeamFnDataClient).forOutboundConsumer(
         eq(PORT_SPEC.getApiServiceDescriptor()),
         eq(KV.of(bundleId.get(), OUTPUT_TARGET)),
         eq(CODER));
@@ -140,7 +235,7 @@ public class BeamFnDataWriteRunnerTest {
 
     assertTrue(valuesB.closed);
     assertThat(valuesB, contains(valueInGlobalWindow("GHI"), valueInGlobalWindow("JKL")));
-    verifyNoMoreInteractions(mockBeamFnDataClientFactory);
+    verifyNoMoreInteractions(mockBeamFnDataClient);
   }
 
   private static class RecordingConsumer<T> extends ArrayList<T>
@@ -158,6 +253,17 @@ public class BeamFnDataWriteRunnerTest {
       }
       add(t);
     }
+  }
 
+  @Test
+  public void testRegistration() {
+    for (Registrar registrar :
+        ServiceLoader.load(Registrar.class)) {
+      if (registrar instanceof BeamFnDataWriteRunner.Registrar) {
+        assertThat(registrar.getPTransformRunnerFactories(), IsMapContaining.hasKey(URN));
+        return;
+      }
+    }
+    fail("Expected registrar not found.");
   }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/c9c1a05d/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BoundedSourceRunnerTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BoundedSourceRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BoundedSourceRunnerTest.java
index d8ed121..6c9a4cb 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BoundedSourceRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/BoundedSourceRunnerTest.java
@@ -20,25 +20,35 @@ package org.apache.beam.runners.core;
 
 import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow;
 import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.collection.IsEmptyCollection.empty;
 import static org.junit.Assert.assertThat;
+import static org.junit.Assert.fail;
 
+import com.google.common.base.Suppliers;
+import com.google.common.collect.HashMultimap;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Multimap;
 import com.google.protobuf.Any;
 import com.google.protobuf.ByteString;
 import com.google.protobuf.BytesValue;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.List;
-import java.util.Map;
+import java.util.ServiceLoader;
 import org.apache.beam.fn.harness.fn.ThrowingConsumer;
+import org.apache.beam.fn.harness.fn.ThrowingRunnable;
+import org.apache.beam.runners.core.PTransformRunnerFactory.Registrar;
 import org.apache.beam.sdk.common.runner.v1.RunnerApi;
 import org.apache.beam.sdk.io.BoundedSource;
 import org.apache.beam.sdk.io.CountingSource;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.util.SerializableUtils;
 import org.apache.beam.sdk.util.WindowedValue;
+import org.hamcrest.Matchers;
+import org.hamcrest.collection.IsMapContaining;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -46,27 +56,25 @@ import org.junit.runners.JUnit4;
 /** Tests for {@link BoundedSourceRunner}. */
 @RunWith(JUnit4.class)
 public class BoundedSourceRunnerTest {
+
+  public static final String URN = "urn:org.apache.beam:source:java:0.1";
+
   @Test
   public void testRunReadLoopWithMultipleSources() throws Exception {
-    List<WindowedValue<Long>> out1ValuesA = new ArrayList<>();
-    List<WindowedValue<Long>> out1ValuesB = new ArrayList<>();
+    List<WindowedValue<Long>> out1Values = new ArrayList<>();
     List<WindowedValue<Long>> out2Values = new ArrayList<>();
-    Map<String, Collection<ThrowingConsumer<WindowedValue<Long>>>> outputMap = ImmutableMap.of(
-        "out1", ImmutableList.of(out1ValuesA::add, out1ValuesB::add),
-        "out2", ImmutableList.of(out2Values::add));
+    Collection<ThrowingConsumer<WindowedValue<Long>>> consumers =
+        ImmutableList.of(out1Values::add, out2Values::add);
 
-    BoundedSourceRunner<BoundedSource<Long>, Long> runner =
-        new BoundedSourceRunner<>(
+    BoundedSourceRunner<BoundedSource<Long>, Long> runner = new BoundedSourceRunner<>(
         PipelineOptionsFactory.create(),
         RunnerApi.FunctionSpec.getDefaultInstance(),
-        outputMap);
+        consumers);
 
     runner.runReadLoop(valueInGlobalWindow(CountingSource.upTo(2)));
     runner.runReadLoop(valueInGlobalWindow(CountingSource.upTo(1)));
 
-    assertThat(out1ValuesA,
-        contains(valueInGlobalWindow(0L), valueInGlobalWindow(1L), valueInGlobalWindow(0L)));
-    assertThat(out1ValuesB,
+    assertThat(out1Values,
         contains(valueInGlobalWindow(0L), valueInGlobalWindow(1L), valueInGlobalWindow(0L)));
     assertThat(out2Values,
         contains(valueInGlobalWindow(0L), valueInGlobalWindow(1L), valueInGlobalWindow(0L)));
@@ -74,40 +82,106 @@ public class BoundedSourceRunnerTest {
 
   @Test
   public void testRunReadLoopWithEmptySource() throws Exception {
-    List<WindowedValue<Long>> out1Values = new ArrayList<>();
-    Map<String, Collection<ThrowingConsumer<WindowedValue<Long>>>> outputMap = ImmutableMap.of(
-        "out1", ImmutableList.of(out1Values::add));
+    List<WindowedValue<Long>> outValues = new ArrayList<>();
+    Collection<ThrowingConsumer<WindowedValue<Long>>> consumers =
+        ImmutableList.of(outValues::add);
 
-    BoundedSourceRunner<BoundedSource<Long>, Long> runner =
-        new BoundedSourceRunner<>(
+    BoundedSourceRunner<BoundedSource<Long>, Long> runner = new BoundedSourceRunner<>(
         PipelineOptionsFactory.create(),
         RunnerApi.FunctionSpec.getDefaultInstance(),
-        outputMap);
+        consumers);
 
     runner.runReadLoop(valueInGlobalWindow(CountingSource.upTo(0)));
 
-    assertThat(out1Values, empty());
+    assertThat(outValues, empty());
   }
 
   @Test
   public void testStart() throws Exception {
     List<WindowedValue<Long>> outValues = new ArrayList<>();
-    Map<String, Collection<ThrowingConsumer<WindowedValue<Long>>>> outputMap = ImmutableMap.of(
-        "out", ImmutableList.of(outValues::add));
+    Collection<ThrowingConsumer<WindowedValue<Long>>> consumers =
+        ImmutableList.of(outValues::add);
 
     ByteString encodedSource =
         ByteString.copyFrom(SerializableUtils.serializeToByteArray(CountingSource.upTo(3)));
 
-    BoundedSourceRunner<BoundedSource<Long>, Long> runner =
-        new BoundedSourceRunner<>(
+    BoundedSourceRunner<BoundedSource<Long>, Long> runner = new BoundedSourceRunner<>(
         PipelineOptionsFactory.create(),
-            RunnerApi.FunctionSpec.newBuilder().setParameter(
+        RunnerApi.FunctionSpec.newBuilder().setParameter(
             Any.pack(BytesValue.newBuilder().setValue(encodedSource).build())).build(),
-        outputMap);
+        consumers);
 
     runner.start();
 
     assertThat(outValues,
         contains(valueInGlobalWindow(0L), valueInGlobalWindow(1L), valueInGlobalWindow(2L)));
   }
+
+  @Test
+  public void testCreatingAndProcessingSourceFromFactory() throws Exception {
+    List<WindowedValue<String>> outputValues = new ArrayList<>();
+
+    Multimap<String, ThrowingConsumer<WindowedValue<?>>> consumers = HashMultimap.create();
+    consumers.put("outputPC",
+        (ThrowingConsumer) (ThrowingConsumer<WindowedValue<String>>) outputValues::add);
+    List<ThrowingRunnable> startFunctions = new ArrayList<>();
+    List<ThrowingRunnable> finishFunctions = new ArrayList<>();
+
+    RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder()
+        .setUrn("urn:org.apache.beam:source:java:0.1")
+        .setParameter(Any.pack(BytesValue.newBuilder()
+            .setValue(ByteString.copyFrom(
+                SerializableUtils.serializeToByteArray(CountingSource.upTo(3))))
+            .build()))
+        .build();
+
+    RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder()
+        .setSpec(functionSpec)
+        .putInputs("input", "inputPC")
+        .putOutputs("output", "outputPC")
+        .build();
+
+    new BoundedSourceRunner.Factory<>().createRunnerForPTransform(
+        PipelineOptionsFactory.create(),
+        null /* beamFnDataClient */,
+        "pTransformId",
+        pTransform,
+        Suppliers.ofInstance("57L")::get,
+        ImmutableMap.of(),
+        ImmutableMap.of(),
+        consumers,
+        startFunctions::add,
+        finishFunctions::add);
+
+    // This is testing a deprecated way of running sources and should be removed
+    // once all source definitions are instead propagated along the input edge.
+    Iterables.getOnlyElement(startFunctions).run();
+    assertThat(outputValues, contains(
+        valueInGlobalWindow(0L),
+        valueInGlobalWindow(1L),
+        valueInGlobalWindow(2L)));
+    outputValues.clear();
+
+    // Check that when passing a source along as an input, the source is processed.
+    assertThat(consumers.keySet(), containsInAnyOrder("inputPC", "outputPC"));
+    Iterables.getOnlyElement(consumers.get("inputPC")).accept(
+        valueInGlobalWindow(CountingSource.upTo(2)));
+    assertThat(outputValues, contains(
+        valueInGlobalWindow(0L),
+        valueInGlobalWindow(1L)));
+
+    assertThat(finishFunctions, Matchers.empty());
+  }
+
+  @Test
+  public void testRegistration() {
+    for (Registrar registrar :
+        ServiceLoader.load(Registrar.class)) {
+      if (registrar instanceof BoundedSourceRunner.Registrar) {
+        assertThat(registrar.getPTransformRunnerFactories(), IsMapContaining.hasKey(URN));
+        return;
+      }
+    }
+    fail("Expected registrar not found.");
+  }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/c9c1a05d/sdks/java/harness/src/test/java/org/apache/beam/runners/core/DoFnRunnerFactoryTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/runners/core/DoFnRunnerFactoryTest.java b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/DoFnRunnerFactoryTest.java
new file mode 100644
index 0000000..62646ff
--- /dev/null
+++ b/sdks/java/harness/src/test/java/org/apache/beam/runners/core/DoFnRunnerFactoryTest.java
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.core;
+
+import static org.apache.beam.sdk.util.WindowedValue.timestampedValueInGlobalWindow;
+import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow;
+import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.containsInAnyOrder;
+import static org.junit.Assert.assertThat;
+import static org.junit.Assert.fail;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.common.base.Suppliers;
+import com.google.common.collect.HashMultimap;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Multimap;
+import com.google.protobuf.Any;
+import com.google.protobuf.ByteString;
+import com.google.protobuf.BytesValue;
+import com.google.protobuf.Message;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.ServiceLoader;
+import org.apache.beam.fn.harness.fn.ThrowingConsumer;
+import org.apache.beam.fn.harness.fn.ThrowingRunnable;
+import org.apache.beam.runners.core.PTransformRunnerFactory.Registrar;
+import org.apache.beam.runners.dataflow.util.CloudObjects;
+import org.apache.beam.runners.dataflow.util.DoFnInfo;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.common.runner.v1.RunnerApi;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.util.SerializableUtils;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.WindowingStrategy;
+import org.hamcrest.collection.IsMapContaining;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for {@link DoFnRunnerFactory}. */
+@RunWith(JUnit4.class)
+public class DoFnRunnerFactoryTest {
+
+  private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
+  private static final Coder<WindowedValue<String>> STRING_CODER =
+      WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE);
+  private static final String STRING_CODER_SPEC_ID = "999L";
+  private static final RunnerApi.Coder STRING_CODER_SPEC;
+  private static final String URN = "urn:org.apache.beam:dofn:java:0.1";
+
+  static {
+    try {
+      STRING_CODER_SPEC = RunnerApi.Coder.newBuilder()
+          .setSpec(RunnerApi.SdkFunctionSpec.newBuilder()
+              .setSpec(RunnerApi.FunctionSpec.newBuilder()
+                  .setParameter(Any.pack(BytesValue.newBuilder().setValue(ByteString.copyFrom(
+                      OBJECT_MAPPER.writeValueAsBytes(CloudObjects.asCloudObject(STRING_CODER))))
+                      .build())))
+              .build())
+          .build();
+    } catch (IOException e) {
+      throw new ExceptionInInitializerError(e);
+    }
+  }
+
+  private static class TestDoFn extends DoFn<String, String> {
+    private static final TupleTag<String> mainOutput = new TupleTag<>("mainOutput");
+    private static final TupleTag<String> additionalOutput = new TupleTag<>("output");
+
+    private BoundedWindow window;
+
+    @ProcessElement
+    public void processElement(ProcessContext context, BoundedWindow window) {
+      context.output("MainOutput" + context.element());
+      context.output(additionalOutput, "AdditionalOutput" + context.element());
+      this.window = window;
+    }
+
+    @FinishBundle
+    public void finishBundle(FinishBundleContext context) {
+      if (window != null) {
+        context.output("FinishBundle", window.maxTimestamp(), window);
+        window = null;
+      }
+    }
+  }
+
+  /**
+   * Create a DoFn that has 3 inputs (inputATarget1, inputATarget2, inputBTarget) and 2 outputs
+   * (mainOutput, output). Validate that inputs are fed to the {@link DoFn} and that outputs
+   * are directed to the correct consumers.
+   */
+  @Test
+  public void testCreatingAndProcessingDoFn() throws Exception {
+    Map<String, Message> fnApiRegistry = ImmutableMap.of(STRING_CODER_SPEC_ID, STRING_CODER_SPEC);
+    String pTransformId = "pTransformId";
+    String mainOutputId = "101";
+    String additionalOutputId = "102";
+
+    DoFnInfo<?, ?> doFnInfo = DoFnInfo.forFn(
+        new TestDoFn(),
+        WindowingStrategy.globalDefault(),
+        ImmutableList.of(),
+        StringUtf8Coder.of(),
+        Long.parseLong(mainOutputId),
+        ImmutableMap.of(
+            Long.parseLong(mainOutputId), TestDoFn.mainOutput,
+            Long.parseLong(additionalOutputId), TestDoFn.additionalOutput));
+    RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder()
+        .setUrn("urn:org.apache.beam:dofn:java:0.1")
+        .setParameter(Any.pack(BytesValue.newBuilder()
+            .setValue(ByteString.copyFrom(SerializableUtils.serializeToByteArray(doFnInfo)))
+            .build()))
+        .build();
+    RunnerApi.PTransform pTransform = RunnerApi.PTransform.newBuilder()
+        .setSpec(functionSpec)
+        .putInputs("inputA", "inputATarget")
+        .putInputs("inputB", "inputBTarget")
+        .putOutputs(mainOutputId, "mainOutputTarget")
+        .putOutputs(additionalOutputId, "additionalOutputTarget")
+        .build();
+
+    List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
+    List<WindowedValue<String>> additionalOutputValues = new ArrayList<>();
+    Multimap<String, ThrowingConsumer<WindowedValue<?>>> consumers = HashMultimap.create();
+    consumers.put("mainOutputTarget",
+        (ThrowingConsumer) (ThrowingConsumer<WindowedValue<String>>) mainOutputValues::add);
+    consumers.put("additionalOutputTarget",
+        (ThrowingConsumer) (ThrowingConsumer<WindowedValue<String>>) additionalOutputValues::add);
+    List<ThrowingRunnable> startFunctions = new ArrayList<>();
+    List<ThrowingRunnable> finishFunctions = new ArrayList<>();
+
+    new DoFnRunnerFactory.Factory<>().createRunnerForPTransform(
+        PipelineOptionsFactory.create(),
+        null /* beamFnDataClient */,
+        pTransformId,
+        pTransform,
+        Suppliers.ofInstance("57L")::get,
+        ImmutableMap.of(),
+        ImmutableMap.of(),
+        consumers,
+        startFunctions::add,
+        finishFunctions::add);
+
+    Iterables.getOnlyElement(startFunctions).run();
+    mainOutputValues.clear();
+
+    assertThat(consumers.keySet(), containsInAnyOrder(
+        "inputATarget", "inputBTarget", "mainOutputTarget", "additionalOutputTarget"));
+
+    Iterables.getOnlyElement(consumers.get("inputATarget")).accept(valueInGlobalWindow("A1"));
+    Iterables.getOnlyElement(consumers.get("inputATarget")).accept(valueInGlobalWindow("A2"));
+    Iterables.getOnlyElement(consumers.get("inputATarget")).accept(valueInGlobalWindow("B"));
+    assertThat(mainOutputValues, contains(
+        valueInGlobalWindow("MainOutputA1"),
+        valueInGlobalWindow("MainOutputA2"),
+        valueInGlobalWindow("MainOutputB")));
+    assertThat(additionalOutputValues, contains(
+        valueInGlobalWindow("AdditionalOutputA1"),
+        valueInGlobalWindow("AdditionalOutputA2"),
+        valueInGlobalWindow("AdditionalOutputB")));
+    mainOutputValues.clear();
+    additionalOutputValues.clear();
+
+    Iterables.getOnlyElement(finishFunctions).run();
+    assertThat(
+        mainOutputValues,
+        contains(
+            timestampedValueInGlobalWindow("FinishBundle", GlobalWindow.INSTANCE.maxTimestamp())));
+    mainOutputValues.clear();
+  }
+
+  @Test
+  public void testRegistration() {
+    for (Registrar registrar :
+        ServiceLoader.load(Registrar.class)) {
+      if (registrar instanceof DoFnRunnerFactory.Registrar) {
+        assertThat(registrar.getPTransformRunnerFactories(), IsMapContaining.hasKey(URN));
+        return;
+      }
+    }
+    fail("Expected registrar not found.");
+  }
+}


Mime
View raw message