beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From k...@apache.org
Subject [10/50] [abbrv] beam git commit: Update Signature of PTransformOverrideFactory
Date Tue, 25 Apr 2017 17:30:05 GMT
http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
----------------------------------------------------------------------
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
index 684dc14..4eec6b8 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
@@ -61,6 +61,7 @@ import java.util.TreeSet;
 import org.apache.beam.runners.core.construction.DeduplicatedFlattenFactory;
 import org.apache.beam.runners.core.construction.EmptyFlattenAsCreateFactory;
 import org.apache.beam.runners.core.construction.PTransformMatchers;
+import org.apache.beam.runners.core.construction.PTransformReplacements;
 import org.apache.beam.runners.core.construction.ReplacementOutputs;
 import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory;
 import org.apache.beam.runners.core.construction.UnboundedReadFromBoundedSource;
@@ -96,6 +97,7 @@ import org.apache.beam.sdk.runners.PipelineRunner;
 import org.apache.beam.sdk.runners.TransformHierarchy;
 import org.apache.beam.sdk.runners.TransformHierarchy.Node;
 import org.apache.beam.sdk.transforms.Aggregator;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.Combine;
 import org.apache.beam.sdk.transforms.Combine.GroupedValues;
 import org.apache.beam.sdk.transforms.DoFn;
@@ -390,25 +392,29 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob>
{
   }
 
   private static class ReflectiveOneToOneOverrideFactory<
-          InputT extends PValue,
-          OutputT extends PValue,
-          TransformT extends PTransform<InputT, OutputT>>
-      extends SingleInputOutputOverrideFactory<InputT, OutputT, TransformT> {
-    private final Class<PTransform<InputT, OutputT>> replacement;
+          InputT, OutputT, TransformT extends PTransform<PCollection<InputT>, PCollection<OutputT>>>
+      extends SingleInputOutputOverrideFactory<
+          PCollection<InputT>, PCollection<OutputT>, TransformT> {
+    private final Class<PTransform<PCollection<InputT>, PCollection<OutputT>>>
replacement;
     private final DataflowRunner runner;
 
     private ReflectiveOneToOneOverrideFactory(
-        Class<PTransform<InputT, OutputT>> replacement, DataflowRunner runner)
{
+        Class<PTransform<PCollection<InputT>, PCollection<OutputT>>>
replacement,
+        DataflowRunner runner) {
       this.replacement = replacement;
       this.runner = runner;
     }
 
     @Override
-    public PTransform<InputT, OutputT> getReplacementTransform(TransformT transform)
{
-      return InstanceBuilder.ofType(replacement)
-          .withArg(DataflowRunner.class, runner)
-          .withArg((Class<PTransform<InputT, OutputT>>) transform.getClass(),
transform)
-          .build();
+    public PTransformReplacement<PCollection<InputT>, PCollection<OutputT>>
getReplacementTransform(
+        AppliedPTransform<PCollection<InputT>, PCollection<OutputT>, TransformT>
transform) {
+      PTransform<PCollection<InputT>, PCollection<OutputT>> rep =
+          InstanceBuilder.ofType(replacement)
+              .withArg(DataflowRunner.class, runner)
+              .withArg(
+                  (Class<TransformT>) transform.getTransform().getClass(), transform.getTransform())
+              .build();
+      return PTransformReplacement.of(PTransformReplacements.getSingletonMainInput(transform),
rep);
     }
   }
 
@@ -423,19 +429,18 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob>
{
       this.replacement = replacement;
       this.runner = runner;
     }
-    @Override
-    public PTransform<PBegin, PCollection<T>> getReplacementTransform(
-        PTransform<PInput, PCollection<T>> transform) {
-      return InstanceBuilder.ofType(replacement)
-          .withArg(DataflowRunner.class, runner)
-          .withArg(
-              (Class<? super PTransform<PInput, PCollection<T>>>) transform.getClass(),
transform)
-          .build();
-    }
 
     @Override
-    public PBegin getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) {
-      return p.begin();
+    public PTransformReplacement<PBegin, PCollection<T>> getReplacementTransform(
+        AppliedPTransform<PBegin, PCollection<T>, PTransform<PInput, PCollection<T>>>
transform) {
+      PTransform<PInput, PCollection<T>> original = transform.getTransform();
+      return PTransformReplacement.of(
+          transform.getPipeline().begin(),
+          InstanceBuilder.ofType(replacement)
+              .withArg(DataflowRunner.class, runner)
+              .withArg(
+                  (Class<? super PTransform<PInput, PCollection<T>>>) original.getClass(),
original)
+              .build());
     }
 
     @Override
@@ -805,13 +810,11 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob>
{
     }
 
     @Override
-    public PTransform<PCollection<T>, PDone> getReplacementTransform(Write<T>
transform) {
-      return new BatchWrite<>(runner, transform);
-    }
-
-    @Override
-    public PCollection<T> getInput(Map<TupleTag<?>, PValue> inputs, Pipeline
p) {
-      return (PCollection<T>) Iterables.getOnlyElement(inputs.values());
+    public PTransformReplacement<PCollection<T>, PDone> getReplacementTransform(
+        AppliedPTransform<PCollection<T>, PDone, Write<T>> transform) {
+      return PTransformReplacement.of(
+          PTransformReplacements.getSingletonMainInput(transform),
+          new BatchWrite<>(runner, transform.getTransform()));
     }
 
     @Override
@@ -1295,15 +1298,15 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob>
{
           PCollection<KV<K, Iterable<InputT>>>, PCollection<KV<K,
OutputT>>,
           Combine.GroupedValues<K, InputT, OutputT>> {
     @Override
-    public PTransform<PCollection<KV<K, Iterable<InputT>>>, PCollection<KV<K,
OutputT>>>
-        getReplacementTransform(GroupedValues<K, InputT, OutputT> transform) {
-      return new CombineGroupedValues<>(transform);
-    }
-
-    @Override
-    public PCollection<KV<K, Iterable<InputT>>> getInput(
-        Map<TupleTag<?>, PValue> inputs, Pipeline p) {
-      return (PCollection<KV<K, Iterable<InputT>>>) Iterables.getOnlyElement(inputs.values());
+    public PTransformReplacement<PCollection<KV<K, Iterable<InputT>>>,
PCollection<KV<K, OutputT>>>
+        getReplacementTransform(
+            AppliedPTransform<
+                    PCollection<KV<K, Iterable<InputT>>>, PCollection<KV<K,
OutputT>>,
+                    GroupedValues<K, InputT, OutputT>>
+                transform) {
+      return PTransformReplacement.of(
+          PTransformReplacements.getSingletonMainInput(transform),
+          new CombineGroupedValues<>(transform.getTransform()));
     }
 
     @Override
@@ -1322,14 +1325,11 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob>
{
     }
 
     @Override
-    public PTransform<PCollection<T>, PDone> getReplacementTransform(
-        PubsubUnboundedSink<T> transform) {
-      return new StreamingPubsubIOWrite<>(runner, transform);
-    }
-
-    @Override
-    public PCollection<T> getInput(Map<TupleTag<?>, PValue> inputs, Pipeline
p) {
-      return (PCollection<T>) Iterables.getOnlyElement(inputs.values());
+    public PTransformReplacement<PCollection<T>, PDone> getReplacementTransform(
+        AppliedPTransform<PCollection<T>, PDone, PubsubUnboundedSink<T>>
transform) {
+      return PTransformReplacement.of(
+          PTransformReplacements.getSingletonMainInput(transform),
+          new StreamingPubsubIOWrite<>(runner, transform.getTransform()));
     }
 
     @Override

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java
----------------------------------------------------------------------
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java
b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java
index db50cc2..2e50cb5 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java
@@ -20,12 +20,15 @@ package org.apache.beam.runners.dataflow;
 
 import java.util.List;
 import org.apache.beam.runners.core.construction.ForwardingPTransform;
+import org.apache.beam.runners.core.construction.PTransformReplacements;
 import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory;
 import org.apache.beam.sdk.common.runner.v1.RunnerApi.DisplayData;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.ParDo.SingleOutput;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionView;
 
@@ -38,9 +41,15 @@ public class PrimitiveParDoSingleFactory<InputT, OutputT>
     extends SingleInputOutputOverrideFactory<
         PCollection<? extends InputT>, PCollection<OutputT>, ParDo.SingleOutput<InputT,
OutputT>> {
   @Override
-  public PTransform<PCollection<? extends InputT>, PCollection<OutputT>>
getReplacementTransform(
-      ParDo.SingleOutput<InputT, OutputT> transform) {
-    return new ParDoSingle<>(transform);
+  public PTransformReplacement<PCollection<? extends InputT>, PCollection<OutputT>>
+      getReplacementTransform(
+          AppliedPTransform<
+                  PCollection<? extends InputT>, PCollection<OutputT>,
+                  SingleOutput<InputT, OutputT>>
+              transform) {
+    return PTransformReplacement.of(
+        PTransformReplacements.getSingletonMainInput(transform),
+        new ParDoSingle<>(transform.getTransform()));
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/ReshuffleOverrideFactory.java
----------------------------------------------------------------------
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/ReshuffleOverrideFactory.java
b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/ReshuffleOverrideFactory.java
index 2e6455d..aa9d9f8 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/ReshuffleOverrideFactory.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/ReshuffleOverrideFactory.java
@@ -18,8 +18,10 @@
 
 package org.apache.beam.runners.dataflow;
 
+import org.apache.beam.runners.core.construction.PTransformReplacements;
 import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.transforms.PTransform;
@@ -43,9 +45,13 @@ class ReshuffleOverrideFactory<K, V>
     extends SingleInputOutputOverrideFactory<
         PCollection<KV<K, V>>, PCollection<KV<K, V>>, Reshuffle<K,
V>> {
   @Override
-  public PTransform<PCollection<KV<K, V>>, PCollection<KV<K, V>>>
getReplacementTransform(
-      Reshuffle<K, V> transform) {
-    return new ReshuffleWithOnlyTrigger<>();
+  public PTransformReplacement<PCollection<KV<K, V>>, PCollection<KV<K,
V>>>
+      getReplacementTransform(
+          AppliedPTransform<PCollection<KV<K, V>>, PCollection<KV<K,
V>>, Reshuffle<K, V>>
+              transform) {
+    return PTransformReplacement.of(
+        PTransformReplacements.getSingletonMainInput(transform),
+        new ReshuffleWithOnlyTrigger<K, V>());
   }
 
   private static class ReshuffleWithOnlyTrigger<K, V>

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/StreamingViewOverrides.java
----------------------------------------------------------------------
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/StreamingViewOverrides.java
b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/StreamingViewOverrides.java
index c407517..eb385de 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/StreamingViewOverrides.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/StreamingViewOverrides.java
@@ -20,11 +20,13 @@ package org.apache.beam.runners.dataflow;
 
 import java.util.ArrayList;
 import java.util.List;
+import org.apache.beam.runners.core.construction.PTransformReplacements;
 import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory;
 import org.apache.beam.runners.dataflow.DataflowRunner.StreamingPCollectionViewWriterFn;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.CoderRegistry;
 import org.apache.beam.sdk.coders.ListCoder;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.Combine;
 import org.apache.beam.sdk.transforms.Combine.CombineFn;
 import org.apache.beam.sdk.transforms.PTransform;
@@ -42,9 +44,15 @@ class StreamingViewOverrides {
       extends SingleInputOutputOverrideFactory<
           PCollection<ElemT>, PCollectionView<ViewT>, CreatePCollectionView<ElemT,
ViewT>> {
     @Override
-    public PTransform<PCollection<ElemT>, PCollectionView<ViewT>> getReplacementTransform(
-        final CreatePCollectionView<ElemT, ViewT> transform) {
-      return new StreamingCreatePCollectionView<>(transform.getView());
+    public PTransformReplacement<PCollection<ElemT>, PCollectionView<ViewT>>
+        getReplacementTransform(
+            AppliedPTransform<
+                    PCollection<ElemT>, PCollectionView<ViewT>, CreatePCollectionView<ElemT,
ViewT>>
+                transform) {
+      StreamingCreatePCollectionView<ElemT, ViewT> streamingView =
+          new StreamingCreatePCollectionView<>(transform.getTransform().getView());
+      return PTransformReplacement.of(
+          PTransformReplacements.getSingletonMainInput(transform), streamingView);
     }
 
     private static class StreamingCreatePCollectionView<ElemT, ViewT>

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactoryTest.java
----------------------------------------------------------------------
diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactoryTest.java
b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactoryTest.java
index bff46ea..e320036 100644
--- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactoryTest.java
+++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactoryTest.java
@@ -27,10 +27,11 @@ import java.io.Serializable;
 import java.util.List;
 import org.apache.beam.runners.dataflow.PrimitiveParDoSingleFactory.ParDoSingle;
 import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.runners.PTransformOverrideFactory.PTransformReplacement;
 import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
-import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.Sum;
 import org.apache.beam.sdk.transforms.View;
@@ -64,17 +65,27 @@ public class PrimitiveParDoSingleFactoryTest implements Serializable {
   public void getReplacementTransformPopulateDisplayData() {
     ParDo.SingleOutput<Integer, Long> originalTransform = ParDo.of(new ToLongFn());
     DisplayData originalDisplayData = DisplayData.from(originalTransform);
-
-    PTransform<PCollection<? extends Integer>, PCollection<Long>> replacement
=
-        factory.getReplacementTransform(originalTransform);
-    DisplayData replacementDisplayData = DisplayData.from(replacement);
+    PCollection<? extends Integer> input = pipeline.apply(Create.of(1, 2, 3));
+    AppliedPTransform<
+        PCollection<? extends Integer>, PCollection<Long>, ParDo.SingleOutput<Integer,
Long>>
+        application =
+        AppliedPTransform.of(
+            "original",
+            input.expand(),
+            input.apply(originalTransform).expand(),
+            originalTransform,
+            pipeline);
+
+    PTransformReplacement<PCollection<? extends Integer>, PCollection<Long>>
replacement =
+        factory.getReplacementTransform(application);
+    DisplayData replacementDisplayData = DisplayData.from(replacement.getTransform());
 
     assertThat(replacementDisplayData, equalTo(originalDisplayData));
 
     DisplayData primitiveDisplayData =
         Iterables.getOnlyElement(
             DisplayDataEvaluator.create()
-                .displayDataForPrimitiveTransforms(replacement, VarIntCoder.of()));
+                .displayDataForPrimitiveTransforms(replacement.getTransform(), VarIntCoder.of()));
     assertThat(primitiveDisplayData, equalTo(replacementDisplayData));
   }
 
@@ -91,9 +102,21 @@ public class PrimitiveParDoSingleFactoryTest implements Serializable {
     ParDo.SingleOutput<Integer, Long> originalTransform =
         ParDo.of(new ToLongFn()).withSideInputs(sideLong, sideStrings);
 
-    PTransform<PCollection<? extends Integer>, PCollection<Long>> replacementTransform
=
-        factory.getReplacementTransform(originalTransform);
-    ParDoSingle<Integer, Long> parDoSingle = (ParDoSingle<Integer, Long>) replacementTransform;
+    PCollection<? extends Integer> input = pipeline.apply(Create.of(1, 2, 3));
+    AppliedPTransform<
+        PCollection<? extends Integer>, PCollection<Long>, ParDo.SingleOutput<Integer,
Long>>
+        application =
+        AppliedPTransform.of(
+            "original",
+            input.expand(),
+            input.apply(originalTransform).expand(),
+            originalTransform,
+            pipeline);
+
+    PTransformReplacement<PCollection<? extends Integer>, PCollection<Long>>
replacementTransform =
+        factory.getReplacementTransform(application);
+    ParDoSingle<Integer, Long> parDoSingle =
+        (ParDoSingle<Integer, Long>) replacementTransform.getTransform();
     assertThat(parDoSingle.getSideInputs(), containsInAnyOrder(sideStrings, sideLong));
   }
 
@@ -101,9 +124,21 @@ public class PrimitiveParDoSingleFactoryTest implements Serializable
{
   public void getReplacementTransformGetFn() {
     DoFn<Integer, Long> originalFn = new ToLongFn();
     ParDo.SingleOutput<Integer, Long> originalTransform = ParDo.of(originalFn);
-    PTransform<PCollection<? extends Integer>, PCollection<Long>> replacementTransform
=
-        factory.getReplacementTransform(originalTransform);
-    ParDoSingle<Integer, Long> parDoSingle = (ParDoSingle<Integer, Long>) replacementTransform;
+    PCollection<? extends Integer> input = pipeline.apply(Create.of(1, 2, 3));
+    AppliedPTransform<
+            PCollection<? extends Integer>, PCollection<Long>, ParDo.SingleOutput<Integer,
Long>>
+        application =
+            AppliedPTransform.of(
+                "original",
+                input.expand(),
+                input.apply(originalTransform).expand(),
+                originalTransform,
+                pipeline);
+
+    PTransformReplacement<PCollection<? extends Integer>, PCollection<Long>>
replacementTransform =
+        factory.getReplacementTransform(application);
+    ParDoSingle<Integer, Long> parDoSingle =
+        (ParDoSingle<Integer, Long>) replacementTransform.getTransform();
 
     assertThat(parDoSingle.getFn(), equalTo(originalTransform.getFn()));
     assertThat(parDoSingle.getFn(), equalTo(originalFn));

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
index aacb942..61fcaa9 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
@@ -46,6 +46,7 @@ import org.apache.beam.sdk.runners.PTransformOverride;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
 import org.apache.beam.sdk.runners.PipelineRunner;
 import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.util.ValueWithRecordId;
@@ -244,14 +245,11 @@ public final class TestSparkRunner extends PipelineRunner<SparkPipelineResult>
{
         implements PTransformOverrideFactory<
             PBegin, PCollection<T>, BoundedReadFromUnboundedSource<T>> {
       @Override
-      public PTransform<PBegin, PCollection<T>> getReplacementTransform(
-          BoundedReadFromUnboundedSource<T> transform) {
-        return new AdaptedBoundedAsUnbounded<>(transform);
-      }
-
-      @Override
-      public PBegin getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) {
-        return p.begin();
+      public PTransformReplacement<PBegin, PCollection<T>> getReplacementTransform(
+          AppliedPTransform<PBegin, PCollection<T>, BoundedReadFromUnboundedSource<T>>
transform) {
+        return PTransformReplacement.of(
+            transform.getPipeline().begin(),
+            new AdaptedBoundedAsUnbounded<T>(transform.getTransform()));
       }
 
       @Override

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java
index 791166e..1ff4c30 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java
@@ -33,11 +33,13 @@ import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.runners.PTransformOverride;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
+import org.apache.beam.sdk.runners.PTransformOverrideFactory.PTransformReplacement;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory.ReplacementOutput;
 import org.apache.beam.sdk.runners.PipelineRunner;
 import org.apache.beam.sdk.runners.TransformHierarchy;
 import org.apache.beam.sdk.runners.TransformHierarchy.Node;
 import org.apache.beam.sdk.transforms.Aggregator;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.util.UserCodeException;
@@ -497,17 +499,18 @@ public class Pipeline {
       void applyReplacement(
           Node original,
           PTransformOverrideFactory<InputT, OutputT, TransformT> replacementFactory)
{
-    PTransform<InputT, OutputT> replacement =
-        replacementFactory.getReplacementTransform((TransformT) original.getTransform());
-    if (replacement == original.getTransform()) {
+    PTransformReplacement<InputT, OutputT> replacement =
+        replacementFactory.getReplacementTransform(
+            (AppliedPTransform<InputT, OutputT, TransformT>) original.toAppliedPTransform());
+    if (replacement.getTransform() == original.getTransform()) {
       return;
     }
-    InputT originalInput = replacementFactory.getInput(original.getInputs(), this);
+    InputT originalInput = replacement.getInput();
 
     LOG.debug("Replacing {} with {}", original, replacement);
-    transforms.replaceNode(original, originalInput, replacement);
+    transforms.replaceNode(original, originalInput, replacement.getTransform());
     try {
-      OutputT newOutput = replacement.expand(originalInput);
+      OutputT newOutput = replacement.getTransform().expand(originalInput);
       Map<PValue, ReplacementOutput> originalToReplacement =
           replacementFactory.mapOutputs(original.getOutputs(), newOutput);
       // Ensure the internal TransformHierarchy data structures are consistent.

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/PTransformOverrideFactory.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/PTransformOverrideFactory.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/PTransformOverrideFactory.java
index 57cba50..786c61c 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/PTransformOverrideFactory.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/PTransformOverrideFactory.java
@@ -21,9 +21,9 @@ package org.apache.beam.sdk.runners;
 
 import com.google.auto.value.AutoValue;
 import java.util.Map;
-import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.annotations.Experimental;
 import org.apache.beam.sdk.annotations.Experimental.Kind;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.values.PInput;
 import org.apache.beam.sdk.values.POutput;
@@ -41,14 +41,11 @@ public interface PTransformOverrideFactory<
     OutputT extends POutput,
     TransformT extends PTransform<? super InputT, OutputT>> {
   /**
-   * Returns a {@link PTransform} that produces equivalent output to the provided transform.
+   * Returns a {@link PTransformReplacement} whose {@link PTransform} produces equivalent output
+   * to the provided {@link AppliedPTransform transform}, together with the input to apply it to.
    */
-  PTransform<InputT, OutputT> getReplacementTransform(TransformT transform);
-
-  /**
-   * Returns the composite type that replacement transforms consumed from an equivalent expansion.
-   */
-  InputT getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p);
+  PTransformReplacement<InputT, OutputT> getReplacementTransform(
+      AppliedPTransform<InputT, OutputT, TransformT> transform);
 
   /**
    * Returns a {@link Map} from the expanded values in {@code newOutput} to the values produced by
@@ -56,7 +53,25 @@ public interface PTransformOverrideFactory<
    */
   Map<PValue, ReplacementOutput> mapOutputs(Map<TupleTag<?>, PValue> outputs,
OutputT newOutput);
 
-  /** A mapping between original {@link TaggedPValue} outputs and their replacements. */
+  /**
+   * A {@link PTransform} that replaces an {@link AppliedPTransform}, and the input required to
+   * do so. The input must be constructed from the expanded form, as the transform may not have
+   * originally been applied within this process or from within a Java SDK.
+   */
+  @AutoValue
+  abstract class PTransformReplacement<InputT extends PInput, OutputT extends POutput>
{
+    public static <InputT extends PInput, OutputT extends POutput>
+        PTransformReplacement<InputT, OutputT> of(
+            InputT input, PTransform<InputT, OutputT> transform) {
+      return new AutoValue_PTransformOverrideFactory_PTransformReplacement(input, transform);
+    }
+    public abstract InputT getInput();
+    public abstract PTransform<InputT, OutputT> getTransform();
+  }
+
+  /**
+   * A mapping between original {@link TaggedPValue} outputs and their replacements.
+   */
   @AutoValue
   abstract class ReplacementOutput {
     public static ReplacementOutput of(TaggedPValue original, TaggedPValue replacement) {

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AppliedPTransform.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AppliedPTransform.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AppliedPTransform.java
index 8d99a62..bdb61b8 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AppliedPTransform.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AppliedPTransform.java
@@ -31,6 +31,11 @@ import org.apache.beam.sdk.values.TupleTag;
  *
  * <p>For internal use.
  *
+ * <p>Inputs and outputs are stored in their expanded forms, as the condensed form of a composite
+ * {@link PInput} or {@link POutput} is a language-specific concept, and {@link AppliedPTransform}
+ * represents a possibly cross-language transform for which no appropriate composite type exists
+ * in the Java SDK.
+ *
  * @param <InputT>     transform input type
  * @param <OutputT>    transform output type
  * @param <TransformT> transform type

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java
index 6ce016d..75cabf2 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java
@@ -406,16 +406,10 @@ public class PipelineTest {
     class ReplacementOverrideFactory
         implements PTransformOverrideFactory<
             PCollection<String>, PCollection<Long>, OriginalTransform> {
-
       @Override
-      public PTransform<PCollection<String>, PCollection<Long>> getReplacementTransform(
-          OriginalTransform transform) {
-        return new ReplacementTransform();
-      }
-
-      @Override
-      public PCollection<String> getInput(Map<TupleTag<?>, PValue> inputs,
Pipeline p) {
-        return originalInput;
+      public PTransformReplacement<PCollection<String>, PCollection<Long>>
getReplacementTransform(
+          AppliedPTransform<PCollection<String>, PCollection<Long>, OriginalTransform>
transform) {
+        return PTransformReplacement.of(originalInput, new ReplacementTransform());
       }
 
       @Override
@@ -464,14 +458,9 @@ public class PipelineTest {
   static class BoundedCountingInputOverride
       implements PTransformOverrideFactory<PBegin, PCollection<Long>, BoundedCountingInput>
{
     @Override
-    public PTransform<PBegin, PCollection<Long>> getReplacementTransform(
-        BoundedCountingInput transform) {
-      return Create.of(0L);
-    }
-
-    @Override
-    public PBegin getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) {
-      return p.begin();
+    public PTransformReplacement<PBegin, PCollection<Long>> getReplacementTransform(
+        AppliedPTransform<PBegin, PCollection<Long>, BoundedCountingInput> transform)
{
+      return PTransformReplacement.of(transform.getPipeline().begin(), Create.of(0L));
     }
 
     @Override
@@ -489,15 +478,11 @@ public class PipelineTest {
   }
   static class UnboundedCountingInputOverride
       implements PTransformOverrideFactory<PBegin, PCollection<Long>, UnboundedCountingInput>
{
-    @Override
-    public PTransform<PBegin, PCollection<Long>> getReplacementTransform(
-        UnboundedCountingInput transform) {
-      return CountingInput.upTo(100L);
-    }
 
     @Override
-    public PBegin getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) {
-      return p.begin();
+    public PTransformReplacement<PBegin, PCollection<Long>> getReplacementTransform(
+        AppliedPTransform<PBegin, PCollection<Long>, UnboundedCountingInput>
transform) {
+      return PTransformReplacement.of(transform.getPipeline().begin(), CountingInput.upTo(100L));
     }
 
     @Override


Mime
View raw message