beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From k...@apache.org
Subject [11/50] [abbrv] beam git commit: Update Signature of PTransformOverrideFactory
Date Tue, 25 Apr 2017 17:30:06 GMT
Update Signature of PTransformOverrideFactory

This enables a replacement to be obtained from the entire transform
application (the AppliedPTransform) that is being replaced, rather than
from the PTransform alone.

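This is required when side inputs are part of the inputs of the
PTransform application: PTransforms are not applied to their side
inputs, so the replacement's main input must be recovered from the
application's inputs with the additional (side) inputs excluded.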


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/f3b49605
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/f3b49605
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/f3b49605

Branch: refs/heads/jstorm-runner
Commit: f3b496053d2596ee1b2de55f6da055b478a0d6d3
Parents: 3c2b855
Author: Thomas Groh <tgroh@google.com>
Authored: Wed Mar 29 15:23:21 2017 -0700
Committer: Thomas Groh <tgroh@google.com>
Committed: Fri Apr 14 16:52:03 2017 -0700

----------------------------------------------------------------------
 .../apache/beam/runners/apex/ApexRunner.java    |  32 +++--
 .../DeduplicatedFlattenFactory.java             |  63 +++++----
 .../EmptyFlattenAsCreateFactory.java            |  20 ++-
 .../core/construction/PTransformMatchers.java   |   2 -
 .../construction/PTransformReplacements.java    |  69 ++++++++++
 .../core/construction/PrimitiveCreate.java      |  13 +-
 .../SingleInputOutputOverrideFactory.java       |   9 +-
 .../UnsupportedOverrideFactory.java             |  14 +-
 .../DeduplicatedFlattenFactoryTest.java         |  18 +--
 .../EmptyFlattenAsCreateFactoryTest.java        |  36 ++++-
 .../PTransformReplacementsTest.java             | 131 +++++++++++++++++++
 .../SingleInputOutputOverrideFactoryTest.java   |  31 ++---
 .../UnsupportedOverrideFactoryTest.java         |  11 +-
 ...ectGBKIntoKeyedWorkItemsOverrideFactory.java |  16 ++-
 .../direct/DirectGroupByKeyOverrideFactory.java |  14 +-
 .../direct/ParDoMultiOverrideFactory.java       |  22 ++--
 .../direct/TestStreamEvaluatorFactory.java      |  14 +-
 .../runners/direct/ViewOverrideFactory.java     |  18 +--
 .../direct/WriteWithShardingFactory.java        |  16 +--
 .../DirectGroupByKeyOverrideFactoryTest.java    |  12 +-
 .../direct/ParDoMultiOverrideFactoryTest.java   |  45 -------
 .../direct/TestStreamEvaluatorFactoryTest.java  |  12 --
 .../runners/direct/ViewOverrideFactoryTest.java |  42 ++++--
 .../direct/WriteWithShardingFactoryTest.java    |  23 ++--
 .../flink/FlinkStreamingPipelineTranslator.java |  56 ++++----
 .../dataflow/BatchStatefulParDoOverrides.java   |  42 +++---
 .../runners/dataflow/BatchViewOverrides.java    |  17 ++-
 .../beam/runners/dataflow/DataflowRunner.java   |  92 ++++++-------
 .../dataflow/PrimitiveParDoSingleFactory.java   |  15 ++-
 .../dataflow/ReshuffleOverrideFactory.java      |  12 +-
 .../dataflow/StreamingViewOverrides.java        |  14 +-
 .../PrimitiveParDoSingleFactoryTest.java        |  59 +++++++--
 .../beam/runners/spark/TestSparkRunner.java     |  14 +-
 .../main/java/org/apache/beam/sdk/Pipeline.java |  15 ++-
 .../sdk/runners/PTransformOverrideFactory.java  |  33 +++--
 .../beam/sdk/transforms/AppliedPTransform.java  |   5 +
 .../java/org/apache/beam/sdk/PipelineTest.java  |  33 ++---
 37 files changed, 675 insertions(+), 415 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunner.java
----------------------------------------------------------------------
diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunner.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunner.java
index 1c99f8d..1c845c6 100644
--- a/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunner.java
+++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunner.java
@@ -39,6 +39,7 @@ import org.apache.apex.api.Launcher.AppHandle;
 import org.apache.apex.api.Launcher.LaunchMode;
 import org.apache.beam.runners.apex.translation.ApexPipelineTranslator;
 import org.apache.beam.runners.core.construction.PTransformMatchers;
+import org.apache.beam.runners.core.construction.PTransformReplacements;
 import org.apache.beam.runners.core.construction.PrimitiveCreate;
 import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory;
 import org.apache.beam.sdk.Pipeline;
@@ -49,6 +50,7 @@ import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.PipelineOptionsValidator;
 import org.apache.beam.sdk.runners.PTransformOverride;
 import org.apache.beam.sdk.runners.PipelineRunner;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.Combine;
 import org.apache.beam.sdk.transforms.Combine.GloballyAsSingletonView;
 import org.apache.beam.sdk.transforms.Create;
@@ -258,9 +260,15 @@ public class ApexRunner extends PipelineRunner<ApexRunnerResult> {
             PCollection<InputT>, PCollectionView<OutputT>,
             Combine.GloballyAsSingletonView<InputT, OutputT>> {
       @Override
-      public PTransform<PCollection<InputT>, PCollectionView<OutputT>> getReplacementTransform(
-          GloballyAsSingletonView<InputT, OutputT> transform) {
-        return new StreamingCombineGloballyAsSingletonView<>(transform);
+      public PTransformReplacement<PCollection<InputT>, PCollectionView<OutputT>>
+          getReplacementTransform(
+              AppliedPTransform<
+                      PCollection<InputT>, PCollectionView<OutputT>,
+                      GloballyAsSingletonView<InputT, OutputT>>
+                  transform) {
+        return PTransformReplacement.of(
+            PTransformReplacements.getSingletonMainInput(transform),
+            new StreamingCombineGloballyAsSingletonView<>(transform.getTransform()));
       }
     }
   }
@@ -321,9 +329,11 @@ public class ApexRunner extends PipelineRunner<ApexRunnerResult> {
         extends SingleInputOutputOverrideFactory<
             PCollection<T>, PCollectionView<T>, View.AsSingleton<T>> {
       @Override
-      public PTransform<PCollection<T>, PCollectionView<T>> getReplacementTransform(
-          AsSingleton<T> transform) {
-        return new StreamingViewAsSingleton<>(transform);
+      public PTransformReplacement<PCollection<T>, PCollectionView<T>> getReplacementTransform(
+          AppliedPTransform<PCollection<T>, PCollectionView<T>, AsSingleton<T>> transform) {
+        return PTransformReplacement.of(
+            PTransformReplacements.getSingletonMainInput(transform),
+            new StreamingViewAsSingleton<>(transform.getTransform()));
       }
     }
   }
@@ -352,9 +362,13 @@ public class ApexRunner extends PipelineRunner<ApexRunnerResult> {
         extends SingleInputOutputOverrideFactory<
             PCollection<T>, PCollectionView<Iterable<T>>, View.AsIterable<T>> {
       @Override
-      public PTransform<PCollection<T>, PCollectionView<Iterable<T>>> getReplacementTransform(
-          AsIterable<T> transform) {
-        return new StreamingViewAsIterable<>();
+      public PTransformReplacement<PCollection<T>, PCollectionView<Iterable<T>>>
+          getReplacementTransform(
+              AppliedPTransform<PCollection<T>, PCollectionView<Iterable<T>>, AsIterable<T>>
+                  transform) {
+        return PTransformReplacement.of(
+            PTransformReplacements.getSingletonMainInput(transform),
+            new StreamingViewAsIterable<T>());
       }
     }
   }

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/DeduplicatedFlattenFactory.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/DeduplicatedFlattenFactory.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/DeduplicatedFlattenFactory.java
index c12c548..13e7593 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/DeduplicatedFlattenFactory.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/DeduplicatedFlattenFactory.java
@@ -18,10 +18,12 @@
 
 package org.apache.beam.runners.core.construction;
 
+import com.google.common.annotations.VisibleForTesting;
 import java.util.HashMap;
 import java.util.Map;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.Flatten;
 import org.apache.beam.sdk.transforms.Flatten.PCollections;
@@ -47,32 +49,11 @@ public class DeduplicatedFlattenFactory<T>
   private DeduplicatedFlattenFactory() {}
 
   @Override
-  public PTransform<PCollectionList<T>, PCollection<T>> getReplacementTransform(
-      PCollections<T> transform) {
-    return new PTransform<PCollectionList<T>, PCollection<T>>() {
-      @Override
-      public PCollection<T> expand(PCollectionList<T> input) {
-        Map<PCollection<T>, Integer> instances = new HashMap<>();
-        for (PCollection<T> pCollection : input.getAll()) {
-          int existing = instances.get(pCollection) == null ? 0 : instances.get(pCollection);
-          instances.put(pCollection, existing + 1);
-        }
-        PCollectionList<T> output = PCollectionList.empty(input.getPipeline());
-        for (Map.Entry<PCollection<T>, Integer> instanceEntry : instances.entrySet()) {
-          if (instanceEntry.getValue().equals(1)) {
-            output = output.and(instanceEntry.getKey());
-          } else {
-            String duplicationName = String.format("Multiply %s", instanceEntry.getKey().getName());
-            PCollection<T> duplicated =
-                instanceEntry
-                    .getKey()
-                    .apply(duplicationName, ParDo.of(new DuplicateFn<T>(instanceEntry.getValue())));
-            output = output.and(duplicated);
-          }
-        }
-        return output.apply(Flatten.<T>pCollections());
-      }
-    };
+  public PTransformReplacement<PCollectionList<T>, PCollection<T>> getReplacementTransform(
+      AppliedPTransform<PCollectionList<T>, PCollection<T>, PCollections<T>> transform) {
+    return PTransformReplacement.of(
+        getInput(transform.getInputs(), transform.getPipeline()),
+        new FlattenWithoutDuplicateInputs<T>());
   }
 
   /**
@@ -80,8 +61,7 @@ public class DeduplicatedFlattenFactory<T>
    *
    * <p>The input {@link PCollectionList} that is constructed will have the same values in the same
    */
-  @Override
-  public PCollectionList<T> getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) {
+  private PCollectionList<T> getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) {
     PCollectionList<T> pCollections = PCollectionList.empty(p);
     for (PValue input : inputs.values()) {
       PCollection<T> pcollection = (PCollection<T>) input;
@@ -96,6 +76,33 @@ public class DeduplicatedFlattenFactory<T>
     return ReplacementOutputs.singleton(outputs, newOutput);
   }
 
+  @VisibleForTesting
+  static class FlattenWithoutDuplicateInputs<T>
+      extends PTransform<PCollectionList<T>, PCollection<T>> {
+    @Override
+    public PCollection<T> expand(PCollectionList<T> input) {
+      Map<PCollection<T>, Integer> instances = new HashMap<>();
+      for (PCollection<T> pCollection : input.getAll()) {
+        int existing = instances.get(pCollection) == null ? 0 : instances.get(pCollection);
+        instances.put(pCollection, existing + 1);
+      }
+      PCollectionList<T> output = PCollectionList.empty(input.getPipeline());
+      for (Map.Entry<PCollection<T>, Integer> instanceEntry : instances.entrySet()) {
+        if (instanceEntry.getValue().equals(1)) {
+          output = output.and(instanceEntry.getKey());
+        } else {
+          String duplicationName = String.format("Multiply %s", instanceEntry.getKey().getName());
+          PCollection<T> duplicated =
+              instanceEntry
+                  .getKey()
+                  .apply(duplicationName, ParDo.of(new DuplicateFn<T>(instanceEntry.getValue())));
+          output = output.and(duplicated);
+        }
+      }
+      return output.apply(Flatten.<T>pCollections());
+    }
+  }
+
   private static class DuplicateFn<T> extends DoFn<T, T> {
     private final int numTimes;
 

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/EmptyFlattenAsCreateFactory.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/EmptyFlattenAsCreateFactory.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/EmptyFlattenAsCreateFactory.java
index 936bc08..a6982d4 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/EmptyFlattenAsCreateFactory.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/EmptyFlattenAsCreateFactory.java
@@ -21,11 +21,12 @@ package org.apache.beam.runners.core.construction;
 import static com.google.common.base.Preconditions.checkArgument;
 
 import java.util.Map;
-import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.VoidCoder;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.Flatten;
+import org.apache.beam.sdk.transforms.Flatten.PCollections;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionList;
@@ -49,20 +50,15 @@ public class EmptyFlattenAsCreateFactory<T>
   private EmptyFlattenAsCreateFactory() {}
 
   @Override
-  public PTransform<PCollectionList<T>, PCollection<T>> getReplacementTransform(
-      Flatten.PCollections<T> transform) {
-    return new CreateEmptyFromList<>();
-  }
-
-  @Override
-  public PCollectionList<T> getInput(
-      Map<TupleTag<?>, PValue> inputs, Pipeline p) {
+  public PTransformReplacement<PCollectionList<T>, PCollection<T>> getReplacementTransform(
+      AppliedPTransform<PCollectionList<T>, PCollection<T>, PCollections<T>> transform) {
     checkArgument(
-        inputs.isEmpty(),
+        transform.getInputs().isEmpty(),
         "Unexpected nonempty input %s for %s",
-        inputs,
+        transform.getInputs(),
         getClass().getSimpleName());
-    return PCollectionList.empty(p);
+    return PTransformReplacement.of(
+        PCollectionList.<T>empty(transform.getPipeline()), new CreateEmptyFromList<T>());
   }
 
   @Override

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java
index 94ec38c..09946bc 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java
@@ -52,8 +52,6 @@ public class PTransformMatchers {
   /**
    * Returns a {@link PTransformMatcher} that matches a {@link PTransform} if the class of the
    * {@link PTransform} is equal to the {@link Class} provided ot this matcher.
-   * @param clazz
-   * @return
    */
   public static PTransformMatcher classEqualTo(Class<? extends PTransform> clazz) {
     return new EqualClassPTransformMatcher(clazz);

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformReplacements.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformReplacements.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformReplacements.java
new file mode 100644
index 0000000..72a3425
--- /dev/null
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformReplacements.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.core.construction;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+import java.util.Map;
+import java.util.Set;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PValue;
+import org.apache.beam.sdk.values.TupleTag;
+
+/**
+ */
+public class PTransformReplacements {
+  /**
+   * Gets the singleton input of an {@link AppliedPTransform}, ignoring any additional inputs
+   * returned by {@link PTransform#getAdditionalInputs()}.
+   */
+  public static <T> PCollection<T> getSingletonMainInput(
+      AppliedPTransform<? extends PCollection<? extends T>, ?, ?> application) {
+    return getSingletonMainInput(
+        application.getInputs(), application.getTransform().getAdditionalInputs().keySet());
+  }
+
+  private static <T> PCollection<T> getSingletonMainInput(
+      Map<TupleTag<?>, PValue> inputs, Set<TupleTag<?>> ignoredTags) {
+    PCollection<T> mainInput = null;
+    for (Map.Entry<TupleTag<?>, PValue> input : inputs.entrySet()) {
+      if (!ignoredTags.contains(input.getKey())) {
+        checkArgument(
+            mainInput == null,
+            "Got multiple inputs that are not additional inputs for a "
+                + "singleton main input: %s and %s",
+            mainInput,
+            input.getValue());
+        checkArgument(
+            input.getValue() instanceof PCollection,
+            "Unexpected input type %s",
+            input.getValue().getClass());
+        mainInput = (PCollection<T>) input.getValue();
+      }
+    }
+    checkArgument(
+        mainInput != null,
+        "No main input found in inputs: Inputs %s, Side Input tags %s",
+        inputs,
+        ignoredTags);
+    return mainInput;
+  }
+}

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PrimitiveCreate.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PrimitiveCreate.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PrimitiveCreate.java
index 9335f3a..5a2140b 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PrimitiveCreate.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PrimitiveCreate.java
@@ -19,8 +19,8 @@
 package org.apache.beam.runners.core.construction;
 
 import java.util.Map;
-import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.Create.Values;
 import org.apache.beam.sdk.transforms.PTransform;
@@ -57,13 +57,10 @@ public class PrimitiveCreate<T> extends PTransform<PBegin, PCollection<T>> {
   public static class Factory<T>
       implements PTransformOverrideFactory<PBegin, PCollection<T>, Values<T>> {
     @Override
-    public PTransform<PBegin, PCollection<T>> getReplacementTransform(Values<T> transform) {
-      return new PrimitiveCreate<>(transform);
-    }
-
-    @Override
-    public PBegin getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) {
-      return p.begin();
+    public PTransformReplacement<PBegin, PCollection<T>> getReplacementTransform(
+        AppliedPTransform<PBegin, PCollection<T>, Values<T>> transform) {
+      return PTransformReplacement.of(
+          transform.getPipeline().begin(), new PrimitiveCreate<T>(transform.getTransform()));
     }
 
     @Override

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SingleInputOutputOverrideFactory.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SingleInputOutputOverrideFactory.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SingleInputOutputOverrideFactory.java
index 6d0d571..7a59c1c 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SingleInputOutputOverrideFactory.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SingleInputOutputOverrideFactory.java
@@ -18,9 +18,7 @@
 
 package org.apache.beam.runners.core.construction;
 
-import com.google.common.collect.Iterables;
 import java.util.Map;
-import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.values.PValue;
@@ -28,7 +26,7 @@ import org.apache.beam.sdk.values.TupleTag;
 
 /**
  * A {@link PTransformOverrideFactory} which consumes from a {@link PValue} and produces a
- * {@link PValue}. {@link #getInput(Map, Pipeline)} and {@link #mapOutputs(Map, PValue)} are
+ * {@link PValue}. {@link #mapOutputs(Map, PValue)} is
  * implemented.
  */
 public abstract class SingleInputOutputOverrideFactory<
@@ -37,11 +35,6 @@ public abstract class SingleInputOutputOverrideFactory<
         TransformT extends PTransform<InputT, OutputT>>
     implements PTransformOverrideFactory<InputT, OutputT, TransformT> {
   @Override
-  public final InputT getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) {
-    return (InputT) Iterables.getOnlyElement(inputs.values());
-  }
-
-  @Override
   public final Map<PValue, ReplacementOutput> mapOutputs(
       Map<TupleTag<?>, PValue> outputs, OutputT newOutput) {
     return ReplacementOutputs.singleton(outputs, newOutput);

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/UnsupportedOverrideFactory.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/UnsupportedOverrideFactory.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/UnsupportedOverrideFactory.java
index 7b9d704..efafa33 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/UnsupportedOverrideFactory.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/UnsupportedOverrideFactory.java
@@ -19,8 +19,8 @@
 package org.apache.beam.runners.core.construction;
 
 import java.util.Map;
-import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.values.PInput;
 import org.apache.beam.sdk.values.POutput;
@@ -29,8 +29,8 @@ import org.apache.beam.sdk.values.TupleTag;
 
 /**
  * A {@link PTransformOverrideFactory} that throws an exception when a call to
- * {@link #getReplacementTransform(PTransform)} is made. This is for {@link PTransform PTransforms}
- * which are not supported by a runner.
+ * {@link #getReplacementTransform(AppliedPTransform)} is made. This is for
+ * {@link PTransform PTransforms} which are not supported by a runner.
  */
 public final class UnsupportedOverrideFactory<
         InputT extends PInput,
@@ -54,12 +54,8 @@ public final class UnsupportedOverrideFactory<
   }
 
   @Override
-  public PTransform<InputT, OutputT> getReplacementTransform(TransformT transform) {
-    throw new UnsupportedOperationException(message);
-  }
-
-  @Override
-  public InputT getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) {
+  public PTransformReplacement<InputT, OutputT> getReplacementTransform(
+      AppliedPTransform<InputT, OutputT, TransformT> transform) {
     throw new UnsupportedOperationException(message);
   }
 

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/DeduplicatedFlattenFactoryTest.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/DeduplicatedFlattenFactoryTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/DeduplicatedFlattenFactoryTest.java
index 14aa1e6..4e08c21 100644
--- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/DeduplicatedFlattenFactoryTest.java
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/DeduplicatedFlattenFactoryTest.java
@@ -22,6 +22,7 @@ import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.not;
 import static org.junit.Assert.assertThat;
 
+import org.apache.beam.runners.core.construction.DeduplicatedFlattenFactory.FlattenWithoutDuplicateInputs;
 import org.apache.beam.sdk.Pipeline.PipelineVisitor.Defaults;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory.ReplacementOutput;
 import org.apache.beam.sdk.runners.TransformHierarchy;
@@ -56,7 +57,7 @@ public class DeduplicatedFlattenFactoryTest {
   @Test
   public void duplicatesInsertsMultipliers() {
     PTransform<PCollectionList<String>, PCollection<String>> replacement =
-        factory.getReplacementTransform(Flatten.<String>pCollections());
+        new DeduplicatedFlattenFactory.FlattenWithoutDuplicateInputs<>();
     final PCollectionList<String> inputList =
         PCollectionList.of(first).and(second).and(first).and(first);
     inputList.apply(replacement);
@@ -74,10 +75,10 @@ public class DeduplicatedFlattenFactoryTest {
   @Test
   @Category(NeedsRunner.class)
   public void testOverride() {
-    PTransform<PCollectionList<String>, PCollection<String>> replacement =
-        factory.getReplacementTransform(Flatten.<String>pCollections());
     final PCollectionList<String> inputList =
         PCollectionList.of(first).and(second).and(first).and(first);
+    PTransform<PCollectionList<String>, PCollection<String>> replacement =
+        new FlattenWithoutDuplicateInputs<>();
     PCollection<String> flattened = inputList.apply(replacement);
 
     PAssert.that(flattened).containsInAnyOrder("one", "two", "one", "one");
@@ -85,21 +86,12 @@ public class DeduplicatedFlattenFactoryTest {
   }
 
   @Test
-  public void inputReconstruction() {
-    final PCollectionList<String> inputList =
-        PCollectionList.of(first).and(second).and(first).and(first);
-
-    assertThat(factory.getInput(inputList.expand(), pipeline), equalTo(inputList));
-  }
-
-  @Test
   public void outputMapping() {
     final PCollectionList<String> inputList =
         PCollectionList.of(first).and(second).and(first).and(first);
     PCollection<String> original =
         inputList.apply(Flatten.<String>pCollections());
-    PCollection<String> replacement =
-        inputList.apply(factory.getReplacementTransform(Flatten.<String>pCollections()));
+    PCollection<String> replacement = inputList.apply(new FlattenWithoutDuplicateInputs<String>());
 
     assertThat(
         factory.mapOutputs(original.expand(), replacement),

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/EmptyFlattenAsCreateFactoryTest.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/EmptyFlattenAsCreateFactoryTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/EmptyFlattenAsCreateFactoryTest.java
index 90bbee7..ae2d0a9 100644
--- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/EmptyFlattenAsCreateFactoryTest.java
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/EmptyFlattenAsCreateFactoryTest.java
@@ -18,17 +18,20 @@
 
 package org.apache.beam.runners.core.construction;
 
-import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.emptyIterable;
 import static org.junit.Assert.assertThat;
 
 import java.util.Collections;
 import java.util.Map;
 import org.apache.beam.sdk.io.CountingInput;
+import org.apache.beam.sdk.runners.PTransformOverrideFactory.PTransformReplacement;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory.ReplacementOutput;
 import org.apache.beam.sdk.testing.NeedsRunner;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.Flatten;
+import org.apache.beam.sdk.transforms.Flatten.PCollections;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionList;
 import org.apache.beam.sdk.values.PValue;
@@ -54,8 +57,15 @@ public class EmptyFlattenAsCreateFactoryTest {
 
   @Test
   public void getInputEmptySucceeds() {
-    assertThat(
-        factory.getInput(Collections.<TupleTag<?>, PValue>emptyMap(), pipeline).size(), equalTo(0));
+    PTransformReplacement<PCollectionList<Long>, PCollection<Long>> replacement =
+        factory.getReplacementTransform(
+            AppliedPTransform.<PCollectionList<Long>, PCollection<Long>, PCollections<Long>>of(
+                "nonEmptyInput",
+                Collections.<TupleTag<?>, PValue>emptyMap(),
+                Collections.<TupleTag<?>, PValue>emptyMap(),
+                Flatten.<Long>pCollections(),
+                pipeline));
+    assertThat(replacement.getInput().getAll(), emptyIterable());
   }
 
   @Test
@@ -66,7 +76,13 @@ public class EmptyFlattenAsCreateFactoryTest {
     thrown.expect(IllegalArgumentException.class);
     thrown.expectMessage(nonEmpty.expand().toString());
     thrown.expectMessage(EmptyFlattenAsCreateFactory.class.getSimpleName());
-    factory.getInput(nonEmpty.expand(), pipeline);
+    factory.getReplacementTransform(
+        AppliedPTransform.<PCollectionList<Long>, PCollection<Long>, Flatten.PCollections<Long>>of(
+            "nonEmptyInput",
+            nonEmpty.expand(),
+            Collections.<TupleTag<?>, PValue>emptyMap(),
+            Flatten.<Long>pCollections(),
+            pipeline));
   }
 
   @Test
@@ -89,7 +105,17 @@ public class EmptyFlattenAsCreateFactoryTest {
   public void testOverride() {
     PCollectionList<Long> empty = PCollectionList.empty(pipeline);
     PCollection<Long> emptyFlattened =
-        empty.apply(factory.getReplacementTransform(Flatten.<Long>pCollections()));
+        empty.apply(
+            factory // build the replacement from an application with empty inputs
+                .getReplacementTransform(
+                    AppliedPTransform
+                        .<PCollectionList<Long>, PCollection<Long>, Flatten.PCollections<Long>>of(
+                            "nonEmptyInput", // NOTE(review): misleading label; inputs are empty
+                            Collections.<TupleTag<?>, PValue>emptyMap(),
+                            Collections.<TupleTag<?>, PValue>emptyMap(),
+                            Flatten.<Long>pCollections(),
+                            pipeline))
+                .getTransform()); // apply only the replacement PTransform itself
     PAssert.that(emptyFlattened).empty();
     pipeline.run();
   }

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformReplacementsTest.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformReplacementsTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformReplacementsTest.java
new file mode 100644
index 0000000..b065617
--- /dev/null
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformReplacementsTest.java
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.core.construction;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.junit.Assert.assertThat;
+
+import com.google.common.collect.ImmutableMap;
+import java.util.Collections;
+import org.apache.beam.sdk.io.CountingInput;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.View;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.PValue;
+import org.apache.beam.sdk.values.TupleTag;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests for {@link PTransformReplacements}, covering {@code getSingletonMainInput}.
+ */
+@RunWith(JUnit4.class)
+public class PTransformReplacementsTest {
+  @Rule public TestPipeline pipeline = TestPipeline.create().enableAbandonedNodeEnforcement(false);
+  @Rule public ExpectedException thrown = ExpectedException.none();
+  private PCollection<Long> mainInput = pipeline.apply(CountingInput.unbounded());
+  private PCollectionView<String> sideInput = // singleton side input
+      pipeline.apply(Create.of("foo")).apply(View.<String>asSingleton());
+
+  private PCollection<Long> output = mainInput.apply(ParDo.of(new TestDoFn()));
+
+  @Test
+  public void getMainInputSingleOutputSingleInput() { // one main input, no side inputs
+    AppliedPTransform<PCollection<Long>, ?, ?> application =
+        AppliedPTransform.of(
+            "application",
+            Collections.<TupleTag<?>, PValue>singletonMap(new TupleTag<Long>(), mainInput),
+            Collections.<TupleTag<?>, PValue>singletonMap(new TupleTag<Long>(), output),
+            ParDo.of(new TestDoFn()),
+            pipeline);
+    PCollection<Long> input = PTransformReplacements.getSingletonMainInput(application);
+    assertThat(input, equalTo(mainInput));
+  }
+
+  @Test
+  public void getMainInputSingleOutputSideInputs() { // side-input entries are skipped
+    AppliedPTransform<PCollection<Long>, ?, ?> application =
+        AppliedPTransform.of(
+            "application",
+            ImmutableMap.<TupleTag<?>, PValue>builder()
+                .put(new TupleTag<Long>(), mainInput)
+                .put(sideInput.getTagInternal(), sideInput.getPCollection())
+                .build(),
+            Collections.<TupleTag<?>, PValue>singletonMap(new TupleTag<Long>(), output),
+            ParDo.of(new TestDoFn()).withSideInputs(sideInput),
+            pipeline);
+    PCollection<Long> input = PTransformReplacements.getSingletonMainInput(application);
+    assertThat(input, equalTo(mainInput));
+  }
+
+  @Test
+  public void getMainInputExtraMainInputsThrows() { // two main inputs are rejected
+    PCollection<Long> notInParDo = pipeline.apply("otherPCollection", Create.of(1L, 2L, 3L));
+    ImmutableMap<TupleTag<?>, PValue> inputs =
+        ImmutableMap.<TupleTag<?>, PValue>builder()
+            .putAll(mainInput.expand())
+            // Not an input of the ParDo below, so it acts as an extra main input
+            .put(new TupleTag<Long>(), notInParDo)
+            .put(sideInput.getTagInternal(), sideInput.getPCollection())
+            .build();
+    AppliedPTransform<PCollection<Long>, ?, ?> application =
+        AppliedPTransform.of(
+            "application",
+            inputs,
+            Collections.<TupleTag<?>, PValue>singletonMap(new TupleTag<Long>(), output),
+            ParDo.of(new TestDoFn()).withSideInputs(sideInput),
+            pipeline);
+    thrown.expect(IllegalArgumentException.class);
+    thrown.expectMessage("multiple inputs");
+    thrown.expectMessage("not additional inputs");
+    thrown.expectMessage(mainInput.toString());
+    thrown.expectMessage(notInParDo.toString());
+    PTransformReplacements.getSingletonMainInput(application);
+  }
+
+  @Test
+  public void getMainInputNoMainInputsThrows() { // only a side input is present
+    ImmutableMap<TupleTag<?>, PValue> inputs =
+        ImmutableMap.<TupleTag<?>, PValue>builder()
+            .put(sideInput.getTagInternal(), sideInput.getPCollection())
+            .build();
+    AppliedPTransform<PCollection<Long>, ?, ?> application =
+        AppliedPTransform.of(
+            "application",
+            inputs,
+            Collections.<TupleTag<?>, PValue>singletonMap(new TupleTag<Long>(), output),
+            ParDo.of(new TestDoFn()).withSideInputs(sideInput),
+            pipeline);
+    thrown.expect(IllegalArgumentException.class);
+    thrown.expectMessage("No main input");
+    PTransformReplacements.getSingletonMainInput(application);
+  }
+
+  private static class TestDoFn extends DoFn<Long, Long> { // no-op; only its types matter
+    @ProcessElement public void process(ProcessContext context) {}
+  }
+}

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SingleInputOutputOverrideFactoryTest.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SingleInputOutputOverrideFactoryTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SingleInputOutputOverrideFactoryTest.java
index 07352f5..acca5cd 100644
--- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SingleInputOutputOverrideFactoryTest.java
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SingleInputOutputOverrideFactoryTest.java
@@ -24,9 +24,9 @@ import java.io.Serializable;
 import java.util.Map;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory.ReplacementOutput;
 import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.MapElements;
-import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.SimpleFunction;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionList;
@@ -55,9 +55,15 @@ public class SingleInputOutputOverrideFactoryTest implements Serializable {
               PCollection<? extends Integer>, PCollection<Integer>,
               MapElements<Integer, Integer>>() {
             @Override
-            public PTransform<PCollection<? extends Integer>, PCollection<Integer>>
-                getReplacementTransform(MapElements<Integer, Integer> transform) {
-              return transform;
+            public PTransformReplacement<PCollection<? extends Integer>, PCollection<Integer>>
+                getReplacementTransform(
+                    AppliedPTransform<
+                            PCollection<? extends Integer>, PCollection<Integer>,
+                            MapElements<Integer, Integer>>
+                        transform) {
+              return PTransformReplacement.of( // identity replacement
+                  PTransformReplacements.getSingletonMainInput(transform),
+                  transform.getTransform());
             }
           };
 
@@ -69,23 +75,6 @@ public class SingleInputOutputOverrideFactoryTest implements Serializable {
     };
 
   @Test
-  public void testGetInput() {
-    PCollection<Integer> input = pipeline.apply(Create.of(1, 2, 3));
-    assertThat(
-        factory.getInput(input.expand(), pipeline),
-        Matchers.<PCollection<? extends Integer>>equalTo(input));
-  }
-
-  @Test
-  public void testGetInputMultipleInputsFails() {
-    PCollection<Integer> input = pipeline.apply(Create.of(1, 2, 3));
-    PCollection<Integer> otherInput = pipeline.apply("OtherCreate", Create.of(1, 2, 3));
-
-    thrown.expect(IllegalArgumentException.class);
-    factory.getInput(PCollectionList.of(input).and(otherInput).expand(), pipeline);
-  }
-
-  @Test
   public void testMapOutputs() {
     PCollection<Integer> input = pipeline.apply(Create.of(1, 2, 3));
     PCollection<Integer> output = input.apply("Map", MapElements.via(fn));

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/UnsupportedOverrideFactoryTest.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/UnsupportedOverrideFactoryTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/UnsupportedOverrideFactoryTest.java
index 81ce00d..6d3b263 100644
--- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/UnsupportedOverrideFactoryTest.java
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/UnsupportedOverrideFactoryTest.java
@@ -19,9 +19,7 @@
 package org.apache.beam.runners.core.construction;
 
 import java.util.Collections;
-import org.apache.beam.sdk.coders.VoidCoder;
 import org.apache.beam.sdk.testing.TestPipeline;
-import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.values.PDone;
 import org.apache.beam.sdk.values.PValue;
 import org.apache.beam.sdk.values.TupleTag;
@@ -47,14 +45,7 @@ public class UnsupportedOverrideFactoryTest {
   public void getReplacementTransformThrows() {
     thrown.expect(UnsupportedOperationException.class);
     thrown.expectMessage(message);
-    factory.getReplacementTransform(Create.empty(VoidCoder.of()));
-  }
-
-  @Test
-  public void getInputThrows() {
-    thrown.expect(UnsupportedOperationException.class);
-    thrown.expectMessage(message);
-    factory.getInput(Collections.<TupleTag<?>, PValue>emptyMap(), pipeline);
+    factory.getReplacementTransform(null);
   }
 
   @Test

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGBKIntoKeyedWorkItemsOverrideFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGBKIntoKeyedWorkItemsOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGBKIntoKeyedWorkItemsOverrideFactory.java
index bb90a6c..1120243 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGBKIntoKeyedWorkItemsOverrideFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGBKIntoKeyedWorkItemsOverrideFactory.java
@@ -19,8 +19,9 @@ package org.apache.beam.runners.direct;
 
 import org.apache.beam.runners.core.KeyedWorkItem;
 import org.apache.beam.runners.core.SplittableParDo.GBKIntoKeyedWorkItems;
+import org.apache.beam.runners.core.construction.PTransformReplacements;
 import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory;
-import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 
@@ -33,8 +34,15 @@ class DirectGBKIntoKeyedWorkItemsOverrideFactory<KeyT, InputT>
         PCollection<KV<KeyT, InputT>>, PCollection<KeyedWorkItem<KeyT, InputT>>,
         GBKIntoKeyedWorkItems<KeyT, InputT>> {
   @Override
-  public PTransform<PCollection<KV<KeyT, InputT>>, PCollection<KeyedWorkItem<KeyT, InputT>>>
-      getReplacementTransform(GBKIntoKeyedWorkItems<KeyT, InputT> transform) {
-    return new DirectGroupByKey.DirectGroupByKeyOnly<>();
+  public PTransformReplacement<
+          PCollection<KV<KeyT, InputT>>, PCollection<KeyedWorkItem<KeyT, InputT>>>
+      getReplacementTransform(
+          AppliedPTransform<
+                  PCollection<KV<KeyT, InputT>>, PCollection<KeyedWorkItem<KeyT, InputT>>,
+                  GBKIntoKeyedWorkItems<KeyT, InputT>>
+              transform) {
+    return PTransformReplacement.of( // same main input, DirectGroupByKeyOnly
+        PTransformReplacements.getSingletonMainInput(transform),
+        new DirectGroupByKey.DirectGroupByKeyOnly<KeyT, InputT>());
   }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactory.java
index f3b718f..4eb0363 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactory.java
@@ -17,10 +17,11 @@
  */
 package org.apache.beam.runners.direct;
 
+import org.apache.beam.runners.core.construction.PTransformReplacements;
 import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.GroupByKey;
-import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 
@@ -29,8 +30,13 @@ final class DirectGroupByKeyOverrideFactory<K, V>
     extends SingleInputOutputOverrideFactory<
         PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>, GroupByKey<K, V>> {
   @Override
-  public PTransform<PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>> getReplacementTransform(
-      GroupByKey<K, V> transform) {
-    return new DirectGroupByKey<>(transform);
+  public PTransformReplacement<PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>>
+      getReplacementTransform(
+          AppliedPTransform<
+                  PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>, GroupByKey<K, V>>
+              transform) {
+    return PTransformReplacement.of( // wrap the original GroupByKey
+        PTransformReplacements.getSingletonMainInput(transform),
+        new DirectGroupByKey<>(transform.getTransform()));
   }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
index 366777b..b08aa8e 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
@@ -19,18 +19,18 @@ package org.apache.beam.runners.direct;
 
 import static com.google.common.base.Preconditions.checkState;
 
-import com.google.common.collect.Iterables;
 import java.util.Map;
 import org.apache.beam.runners.core.KeyedWorkItem;
 import org.apache.beam.runners.core.KeyedWorkItemCoder;
 import org.apache.beam.runners.core.KeyedWorkItems;
 import org.apache.beam.runners.core.SplittableParDo;
+import org.apache.beam.runners.core.construction.PTransformReplacements;
 import org.apache.beam.runners.core.construction.ReplacementOutputs;
-import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.CannotProvideCoderException;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.transforms.PTransform;
@@ -62,8 +62,18 @@ class ParDoMultiOverrideFactory<InputT, OutputT>
     implements PTransformOverrideFactory<
         PCollection<? extends InputT>, PCollectionTuple, MultiOutput<InputT, OutputT>> {
   @Override
+  public PTransformReplacement<PCollection<? extends InputT>, PCollectionTuple>
+      getReplacementTransform(
+          AppliedPTransform<
+                  PCollection<? extends InputT>, PCollectionTuple, MultiOutput<InputT, OutputT>>
+              transform) {
+    return PTransformReplacement.of( // transform built by the private overload
+        PTransformReplacements.getSingletonMainInput(transform),
+        getReplacementTransform(transform.getTransform()));
+  }
+
   @SuppressWarnings("unchecked")
-  public PTransform<PCollection<? extends InputT>, PCollectionTuple> getReplacementTransform(
+  private PTransform<PCollection<? extends InputT>, PCollectionTuple> getReplacementTransform(
       MultiOutput<InputT, OutputT> transform) {
 
     DoFn<InputT, OutputT> fn = transform.getFn();
@@ -84,12 +94,6 @@ class ParDoMultiOverrideFactory<InputT, OutputT>
   }
 
   @Override
-  public PCollection<? extends InputT> getInput(
-      Map<TupleTag<?>, PValue> inputs, Pipeline p) {
-    return (PCollection<? extends InputT>) Iterables.getOnlyElement(inputs.values());
-  }
-
-  @Override
   public Map<PValue, ReplacementOutput> mapOutputs(
       Map<TupleTag<?>, PValue> outputs, PCollectionTuple newOutput) {
     return ReplacementOutputs.tagged(outputs, newOutput);

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java
index 6e0a4fc..cba754e 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java
@@ -31,7 +31,6 @@ import javax.annotation.Nullable;
 import org.apache.beam.runners.core.construction.ReplacementOutputs;
 import org.apache.beam.runners.direct.DirectRunner.CommittedBundle;
 import org.apache.beam.runners.direct.DirectRunner.UncommittedBundle;
-import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
 import org.apache.beam.sdk.testing.TestStream;
 import org.apache.beam.sdk.testing.TestStream.ElementEvent;
@@ -170,14 +169,11 @@ class TestStreamEvaluatorFactory implements TransformEvaluatorFactory {
     }
 
     @Override
-    public PTransform<PBegin, PCollection<T>> getReplacementTransform(
-        TestStream<T> transform) {
-      return new DirectTestStream<>(runner, transform);
-    }
-
-    @Override
-    public PBegin getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) {
-      return p.begin();
+    public PTransformReplacement<PBegin, PCollection<T>> getReplacementTransform(
+        AppliedPTransform<PBegin, PCollection<T>, TestStream<T>> transform) {
+      return PTransformReplacement.of(
+          transform.getPipeline().begin(), // input type is PBegin: use begin()
+          new DirectTestStream<T>(runner, transform.getTransform()));
     }
 
     @Override

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewOverrideFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewOverrideFactory.java
index 52dc329..d4fd18f 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewOverrideFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewOverrideFactory.java
@@ -18,14 +18,14 @@
 
 package org.apache.beam.runners.direct;
 
-import com.google.common.collect.Iterables;
 import java.util.Collections;
 import java.util.Map;
 import org.apache.beam.runners.core.construction.ForwardingPTransform;
-import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.runners.core.construction.PTransformReplacements;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.VoidCoder;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.Values;
@@ -43,15 +43,15 @@ import org.apache.beam.sdk.values.TupleTag;
 class ViewOverrideFactory<ElemT, ViewT>
     implements PTransformOverrideFactory<
         PCollection<ElemT>, PCollectionView<ViewT>, CreatePCollectionView<ElemT, ViewT>> {
-  @Override
-  public PTransform<PCollection<ElemT>, PCollectionView<ViewT>> getReplacementTransform(
-      CreatePCollectionView<ElemT, ViewT> transform) {
-    return new GroupAndWriteView<>(transform);
-  }
 
   @Override
-  public PCollection<ElemT> getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) {
-    return (PCollection<ElemT>) Iterables.getOnlyElement(inputs.values());
+  public PTransformReplacement<PCollection<ElemT>, PCollectionView<ViewT>> getReplacementTransform(
+      AppliedPTransform<
+              PCollection<ElemT>, PCollectionView<ViewT>, CreatePCollectionView<ElemT, ViewT>>
+          transform) {
+    return PTransformReplacement.of( // wrap the original in GroupAndWriteView
+        PTransformReplacements.getSingletonMainInput(transform),
+        new GroupAndWriteView<>(transform.getTransform()));
   }
 
   @Override

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java
index b3f92ab..a23ab94 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java
@@ -21,14 +21,14 @@ package org.apache.beam.runners.direct;
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Supplier;
 import com.google.common.base.Suppliers;
-import com.google.common.collect.Iterables;
 import java.io.Serializable;
 import java.util.Collections;
 import java.util.Map;
 import java.util.concurrent.ThreadLocalRandom;
-import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.runners.core.construction.PTransformReplacements;
 import org.apache.beam.sdk.io.Write;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.Count;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.PTransform;
@@ -53,14 +53,12 @@ class WriteWithShardingFactory<InputT>
   @VisibleForTesting static final int MIN_SHARDS_FOR_LOG = 3;
 
   @Override
-  public PTransform<PCollection<InputT>, PDone> getReplacementTransform(
-      Write<InputT> transform) {
-    return transform.withSharding(new LogElementShardsWithDrift<InputT>());
-  }
+  public PTransformReplacement<PCollection<InputT>, PDone> getReplacementTransform(
+      AppliedPTransform<PCollection<InputT>, PDone, Write<InputT>> transform) {
 
-  @Override
-  public PCollection<InputT> getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) {
-    return (PCollection<InputT>) Iterables.getOnlyElement(inputs.values());
+    return PTransformReplacement.of( // same Write, re-sharded with drift logging
+        PTransformReplacements.getSingletonMainInput(transform),
+        transform.getTransform().withSharding(new LogElementShardsWithDrift<InputT>()));
   }
 
   @Override

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactoryTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactoryTest.java
index c9fdda0..28fef4c 100644
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactoryTest.java
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactoryTest.java
@@ -23,8 +23,11 @@ import static org.junit.Assert.assertThat;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.runners.PTransformOverrideFactory.PTransformReplacement;
 import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.hamcrest.Matchers;
@@ -45,7 +48,12 @@ public class DirectGroupByKeyOverrideFactoryTest {
         p.apply(
             Create.of(KV.of("foo", 1))
                 .withCoder(KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of())));
-    PCollection<?> reconstructed = factory.getInput(input.expand(), p);
-    assertThat(reconstructed, Matchers.<PCollection<?>>equalTo(input));
+    PCollection<KV<String, Iterable<Integer>>> grouped =
+        input.apply(GroupByKey.<String, Integer>create());
+    AppliedPTransform<?, ?, ?> producer = DirectGraphs.getProducer(grouped);
+    PTransformReplacement<
+            PCollection<KV<String, Integer>>, PCollection<KV<String, Iterable<Integer>>>>
+        replacement = factory.getReplacementTransform((AppliedPTransform) producer);
+    assertThat(replacement.getInput(), Matchers.<PCollection<?>>equalTo(input));
   }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactoryTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactoryTest.java
deleted file mode 100644
index 4bbf924..0000000
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactoryTest.java
+++ /dev/null
@@ -1,45 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.beam.runners.direct;
-
-import static org.junit.Assert.assertThat;
-
-import org.apache.beam.sdk.testing.TestPipeline;
-import org.apache.beam.sdk.transforms.Create;
-import org.apache.beam.sdk.values.PCollection;
-import org.hamcrest.Matchers;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.JUnit4;
-
-/**
- * Tests for {@link ParDoMultiOverrideFactory}.
- */
-@RunWith(JUnit4.class)
-public class ParDoMultiOverrideFactoryTest {
-  private ParDoMultiOverrideFactory factory = new ParDoMultiOverrideFactory();
-
-  @Test
-  public void getInputSucceeds() {
-    TestPipeline p = TestPipeline.create();
-    PCollection<Integer> input = p.apply(Create.of(1, 2, 3));
-    PCollection<?> reconstructed = factory.getInput(input.expand(), p);
-    assertThat(reconstructed, Matchers.<PCollection<?>>equalTo(input));
-  }
-}

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java
index 0d909c2..b9c6e64 100644
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java
@@ -27,22 +27,17 @@ import com.google.common.collect.Iterables;
 import java.util.Collection;
 import java.util.Collections;
 import org.apache.beam.runners.direct.DirectRunner.CommittedBundle;
-import org.apache.beam.runners.direct.TestStreamEvaluatorFactory.DirectTestStreamFactory;
 import org.apache.beam.runners.direct.TestStreamEvaluatorFactory.DirectTestStreamFactory.DirectTestStream;
 import org.apache.beam.runners.direct.TestStreamEvaluatorFactory.TestClock;
 import org.apache.beam.runners.direct.TestStreamEvaluatorFactory.TestStreamIndex;
-import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.VarIntCoder;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.testing.TestStream;
 import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
-import org.apache.beam.sdk.values.PValue;
 import org.apache.beam.sdk.values.TimestampedValue;
-import org.apache.beam.sdk.values.TupleTag;
 import org.hamcrest.Matchers;
 import org.joda.time.Duration;
 import org.joda.time.Instant;
@@ -180,11 +175,4 @@ public class TestStreamEvaluatorFactoryTest {
     assertThat(fifthResult.getWatermarkHold(), equalTo(BoundedWindow.TIMESTAMP_MAX_VALUE));
     assertThat(fifthResult.getUnprocessedElements(), Matchers.emptyIterable());
   }
-
-  @Test
-  public void overrideFactoryGetInputSucceeds() {
-    DirectTestStreamFactory<?> factory = new DirectTestStreamFactory<>(runner);
-    PBegin begin = factory.getInput(Collections.<TupleTag<?>, PValue>emptyMap(), p);
-    assertThat(begin.getPipeline(), Matchers.<Pipeline>equalTo(p));
-  }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewOverrideFactoryTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewOverrideFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewOverrideFactoryTest.java
index 258cb46..6875e1a 100644
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewOverrideFactoryTest.java
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewOverrideFactoryTest.java
@@ -30,12 +30,13 @@ import java.util.Set;
 import java.util.concurrent.atomic.AtomicBoolean;
 import org.apache.beam.runners.direct.ViewOverrideFactory.WriteView;
 import org.apache.beam.sdk.Pipeline.PipelineVisitor;
+import org.apache.beam.sdk.runners.PTransformOverrideFactory.PTransformReplacement;
 import org.apache.beam.sdk.runners.TransformHierarchy.Node;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
-import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.View.CreatePCollectionView;
 import org.apache.beam.sdk.util.PCollectionViews;
@@ -62,9 +63,20 @@ public class ViewOverrideFactoryTest implements Serializable {
     PCollection<Integer> ints = p.apply("CreateContents", Create.of(1, 2, 3));
     final PCollectionView<List<Integer>> view =
         PCollectionViews.listView(ints, WindowingStrategy.globalDefault(), ints.getCoder());
-    PTransform<PCollection<Integer>, PCollectionView<List<Integer>>> replacementTransform =
-        factory.getReplacementTransform(CreatePCollectionView.<Integer, List<Integer>>of(view));
-    PCollectionView<List<Integer>> afterReplacement = ints.apply(replacementTransform);
+    PTransformReplacement<PCollection<Integer>, PCollectionView<List<Integer>>>
+        replacementTransform =
+            factory.getReplacementTransform(
+                AppliedPTransform
+                    .<PCollection<Integer>, PCollectionView<List<Integer>>,
+                        CreatePCollectionView<Integer, List<Integer>>>
+                        of(
+                            "foo",
+                            ints.expand(),
+                            view.expand(),
+                            CreatePCollectionView.<Integer, List<Integer>>of(view),
+                            p));
+    PCollectionView<List<Integer>> afterReplacement =
+        ints.apply(replacementTransform.getTransform());
     assertThat(
         "The CreatePCollectionView replacement should return the same View",
         afterReplacement,
@@ -92,9 +104,18 @@ public class ViewOverrideFactoryTest implements Serializable {
     final PCollection<Integer> ints = p.apply("CreateContents", Create.of(1, 2, 3));
     final PCollectionView<List<Integer>> view =
         PCollectionViews.listView(ints, WindowingStrategy.globalDefault(), ints.getCoder());
-    PTransform<PCollection<Integer>, PCollectionView<List<Integer>>> replacement =
-        factory.getReplacementTransform(CreatePCollectionView.<Integer, List<Integer>>of(view));
-    ints.apply(replacement);
+    PTransformReplacement<PCollection<Integer>, PCollectionView<List<Integer>>> replacement =
+        factory.getReplacementTransform(
+            AppliedPTransform
+                .<PCollection<Integer>, PCollectionView<List<Integer>>,
+                    CreatePCollectionView<Integer, List<Integer>>>
+                    of(
+                        "foo",
+                        ints.expand(),
+                        view.expand(),
+                        CreatePCollectionView.<Integer, List<Integer>>of(view),
+                        p));
+    ints.apply(replacement.getTransform());
     final AtomicBoolean writeViewVisited = new AtomicBoolean();
     p.traverseTopologically(
         new PipelineVisitor.Defaults() {
@@ -114,11 +135,4 @@ public class ViewOverrideFactoryTest implements Serializable {
 
     assertThat(writeViewVisited.get(), is(true));
   }
-
-  @Test
-  public void overrideFactoryGetInputSucceeds() {
-    ViewOverrideFactory<String, String> factory = new ViewOverrideFactory<>();
-    PCollection<String> input = p.apply(Create.of("foo", "bar"));
-    assertThat(factory.getInput(input.expand(), p), equalTo(input));
-  }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java
index 8720fd1..361850d 100644
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java
@@ -38,11 +38,13 @@ import java.util.List;
 import java.util.UUID;
 import org.apache.beam.runners.direct.WriteWithShardingFactory.CalculateShardsFn;
 import org.apache.beam.sdk.coders.VarLongCoder;
+import org.apache.beam.sdk.coders.VoidCoder;
 import org.apache.beam.sdk.io.Sink;
 import org.apache.beam.sdk.io.TextIO;
 import org.apache.beam.sdk.io.Write;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.DoFnTester;
@@ -52,7 +54,9 @@ import org.apache.beam.sdk.util.PCollectionViews;
 import org.apache.beam.sdk.util.WindowingStrategy;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionView;
-import org.hamcrest.Matchers;
+import org.apache.beam.sdk.values.PDone;
+import org.apache.beam.sdk.values.PValue;
+import org.apache.beam.sdk.values.TupleTag;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
@@ -118,7 +122,15 @@ public class WriteWithShardingFactoryTest {
   @Test
   public void withNoShardingSpecifiedReturnsNewTransform() {
     Write<Object> original = Write.to(new TestSink());
-    assertThat(factory.getReplacementTransform(original), not(equalTo((Object) original)));
+    PCollection<Object> objs = (PCollection) p.apply(Create.empty(VoidCoder.of()));
+
+    AppliedPTransform<PCollection<Object>, PDone, Write<Object>> originalApplication =
+        AppliedPTransform.of(
+            "write", objs.expand(), Collections.<TupleTag<?>, PValue>emptyMap(), original, p);
+
+    assertThat(
+        factory.getReplacementTransform(originalApplication).getTransform(),
+        not(equalTo((Object) original)));
   }
 
   @Test
@@ -195,13 +207,6 @@ public class WriteWithShardingFactoryTest {
     assertThat(shards, containsInAnyOrder(13));
   }
 
-  @Test
-  public void getInputSucceeds() {
-    PCollection<String> original = p.apply(Create.of("foo"));
-    PCollection<?> input = factory.getInput(original.expand(), p);
-    assertThat(input, Matchers.<PCollection<?>>equalTo(original));
-  }
-
   private static class TestSink extends Sink<Object> {
     @Override
     public void validate(PipelineOptions options) {}

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java
----------------------------------------------------------------------
diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java
index 70da2b3..0459ef7 100644
--- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java
+++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java
@@ -18,11 +18,11 @@
 package org.apache.beam.runners.flink;
 
 import com.google.common.collect.ImmutableList;
-import com.google.common.collect.Iterables;
 import java.util.List;
 import java.util.Map;
 import org.apache.beam.runners.core.SplittableParDo;
 import org.apache.beam.runners.core.construction.PTransformMatchers;
+import org.apache.beam.runners.core.construction.PTransformReplacements;
 import org.apache.beam.runners.core.construction.ReplacementOutputs;
 import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory;
 import org.apache.beam.sdk.Pipeline;
@@ -30,9 +30,9 @@ import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.runners.PTransformOverride;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
 import org.apache.beam.sdk.runners.TransformHierarchy;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.Combine;
 import org.apache.beam.sdk.transforms.PTransform;
-import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.ParDo.MultiOutput;
 import org.apache.beam.sdk.transforms.View;
 import org.apache.beam.sdk.util.InstanceBuilder;
@@ -221,46 +221,50 @@ class FlinkStreamingPipelineTranslator extends FlinkPipelineTranslator {
   }
 
   private static class ReflectiveOneToOneOverrideFactory<
-      InputT extends PValue,
-      OutputT extends PValue,
-      TransformT extends PTransform<InputT, OutputT>>
-      extends SingleInputOutputOverrideFactory<InputT, OutputT, TransformT> {
-    private final Class<PTransform<InputT, OutputT>> replacement;
+          InputT, OutputT, TransformT extends PTransform<PCollection<InputT>, PCollection<OutputT>>>
+      extends SingleInputOutputOverrideFactory<
+          PCollection<InputT>, PCollection<OutputT>, TransformT> {
+    private final Class<PTransform<PCollection<InputT>, PCollection<OutputT>>> replacement;
     private final FlinkRunner runner;
 
     private ReflectiveOneToOneOverrideFactory(
-        Class<PTransform<InputT, OutputT>> replacement, FlinkRunner runner) {
+        Class<PTransform<PCollection<InputT>, PCollection<OutputT>>> replacement,
+        FlinkRunner runner) {
       this.replacement = replacement;
       this.runner = runner;
     }
 
     @Override
-    public PTransform<InputT, OutputT> getReplacementTransform(TransformT transform) {
-      return InstanceBuilder.ofType(replacement)
-          .withArg(FlinkRunner.class, runner)
-          .withArg((Class<PTransform<InputT, OutputT>>) transform.getClass(), transform)
-          .build();
+    public PTransformReplacement<PCollection<InputT>, PCollection<OutputT>> getReplacementTransform(
+        AppliedPTransform<PCollection<InputT>, PCollection<OutputT>, TransformT> transform) {
+      return PTransformReplacement.of(
+          PTransformReplacements.getSingletonMainInput(transform),
+          InstanceBuilder.ofType(replacement)
+              .withArg(FlinkRunner.class, runner)
+              .withArg(
+                  (Class<PTransform<PCollection<InputT>, PCollection<OutputT>>>)
+                      transform.getTransform().getClass(),
+                  transform.getTransform())
+              .build());
     }
   }
 
   /**
-   * A {@link PTransformOverrideFactory} that overrides a
-   * <a href="https://s.apache.org/splittable-do-fn">Splittable DoFn</a> with
-   * {@link SplittableParDo}.
+   * A {@link PTransformOverrideFactory} that overrides a <a
+   * href="https://s.apache.org/splittable-do-fn">Splittable DoFn</a> with {@link SplittableParDo}.
    */
   static class SplittableParDoOverrideFactory<InputT, OutputT>
       implements PTransformOverrideFactory<
-            PCollection<? extends InputT>, PCollectionTuple, MultiOutput<InputT, OutputT>> {
+          PCollection<InputT>, PCollectionTuple, MultiOutput<InputT, OutputT>> {
     @Override
-    @SuppressWarnings("unchecked")
-    public PTransform<PCollection<? extends InputT>, PCollectionTuple> getReplacementTransform(
-        ParDo.MultiOutput<InputT, OutputT> transform) {
-      return new SplittableParDo(transform);
-    }
-
-    @Override
-    public PCollection<? extends InputT> getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) {
-      return (PCollection<? extends InputT>) Iterables.getOnlyElement(inputs.values());
+    public PTransformReplacement<PCollection<InputT>, PCollectionTuple>
+        getReplacementTransform(
+            AppliedPTransform<
+                    PCollection<InputT>, PCollectionTuple, MultiOutput<InputT, OutputT>>
+                transform) {
+      return PTransformReplacement.of(
+          PTransformReplacements.getSingletonMainInput(transform),
+          new SplittableParDo<>(transform.getTransform()));
     }
 
     @Override

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java
----------------------------------------------------------------------
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java
index 73f3728..119c9c9 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java
@@ -19,19 +19,21 @@ package org.apache.beam.runners.dataflow;
 
 import static com.google.common.base.Preconditions.checkState;
 
-import com.google.common.collect.Iterables;
 import java.util.Map;
+import org.apache.beam.runners.core.construction.PTransformReplacements;
 import org.apache.beam.runners.core.construction.ReplacementOutputs;
 import org.apache.beam.runners.dataflow.BatchViewOverrides.GroupByKeyAndSortValuesOnly;
-import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.InstantCoder;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.ParDo.MultiOutput;
+import org.apache.beam.sdk.transforms.ParDo.SingleOutput;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
@@ -85,15 +87,15 @@ public class BatchStatefulParDoOverrides {
           ParDo.SingleOutput<KV<K, InputT>, OutputT>> {
 
     @Override
-    @SuppressWarnings("unchecked")
-    public PTransform<PCollection<KV<K, InputT>>, PCollection<OutputT>> getReplacementTransform(
-        ParDo.SingleOutput<KV<K, InputT>, OutputT> originalParDo) {
-      return new StatefulSingleOutputParDo<>(originalParDo);
-    }
-
-    @Override
-    public PCollection<KV<K, InputT>> getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) {
-      return (PCollection<KV<K, InputT>>) Iterables.getOnlyElement(inputs.values());
+    public PTransformReplacement<PCollection<KV<K, InputT>>, PCollection<OutputT>>
+        getReplacementTransform(
+            AppliedPTransform<
+                    PCollection<KV<K, InputT>>, PCollection<OutputT>,
+                    SingleOutput<KV<K, InputT>, OutputT>>
+                transform) {
+      return PTransformReplacement.of(
+          PTransformReplacements.getSingletonMainInput(transform),
+          new StatefulSingleOutputParDo<>(transform.getTransform()));
     }
 
     @Override
@@ -108,15 +110,15 @@ public class BatchStatefulParDoOverrides {
           PCollection<KV<K, InputT>>, PCollectionTuple, ParDo.MultiOutput<KV<K, InputT>, OutputT>> {
 
     @Override
-    @SuppressWarnings("unchecked")
-    public PTransform<PCollection<KV<K, InputT>>, PCollectionTuple> getReplacementTransform(
-        ParDo.MultiOutput<KV<K, InputT>, OutputT> originalParDo) {
-      return new StatefulMultiOutputParDo<>(originalParDo);
-    }
-
-    @Override
-    public PCollection<KV<K, InputT>> getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) {
-      return (PCollection<KV<K, InputT>>) Iterables.getOnlyElement(inputs.values());
+    public PTransformReplacement<PCollection<KV<K, InputT>>, PCollectionTuple>
+        getReplacementTransform(
+            AppliedPTransform<
+                    PCollection<KV<K, InputT>>, PCollectionTuple,
+                    MultiOutput<KV<K, InputT>, OutputT>>
+                transform) {
+      return PTransformReplacement.of(
+          PTransformReplacements.getSingletonMainInput(transform),
+          new StatefulMultiOutputParDo<>(transform.getTransform()));
     }
 
     @Override

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchViewOverrides.java
----------------------------------------------------------------------
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchViewOverrides.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchViewOverrides.java
index ead2712..1565fd1 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchViewOverrides.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchViewOverrides.java
@@ -42,6 +42,7 @@ import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import org.apache.beam.runners.core.construction.PTransformReplacements;
 import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory;
 import org.apache.beam.runners.dataflow.internal.IsmFormat;
 import org.apache.beam.runners.dataflow.internal.IsmFormat.IsmRecord;
@@ -59,6 +60,7 @@ import org.apache.beam.sdk.coders.SerializableCoder;
 import org.apache.beam.sdk.coders.StandardCoder;
 import org.apache.beam.sdk.coders.VarIntCoder;
 import org.apache.beam.sdk.coders.VarLongCoder;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.Combine;
 import org.apache.beam.sdk.transforms.Combine.GloballyAsSingletonView;
 import org.apache.beam.sdk.transforms.CombineFnBase.GlobalCombineFn;
@@ -1404,10 +1406,17 @@ class BatchViewOverrides {
     }
 
     @Override
-    public PTransform<PCollection<ElemT>, PCollectionView<ViewT>> getReplacementTransform(
-        final GloballyAsSingletonView<ElemT, ViewT> transform) {
-      return new BatchCombineGloballyAsSingletonView<>(
-          runner, transform.getCombineFn(), transform.getFanout(), transform.getInsertDefault());
+    public PTransformReplacement<PCollection<ElemT>, PCollectionView<ViewT>>
+        getReplacementTransform(
+            AppliedPTransform<
+                    PCollection<ElemT>, PCollectionView<ViewT>,
+                    GloballyAsSingletonView<ElemT, ViewT>>
+                transform) {
+      GloballyAsSingletonView<ElemT, ViewT> combine = transform.getTransform();
+      return PTransformReplacement.of(
+          PTransformReplacements.getSingletonMainInput(transform),
+          new BatchCombineGloballyAsSingletonView<>(
+              runner, combine.getCombineFn(), combine.getFanout(), combine.getInsertDefault()));
     }
 
     private static class BatchCombineGloballyAsSingletonView<ElemT, ViewT>


Mime
View raw message