beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From k...@apache.org
Subject [1/2] incubator-beam git commit: Use input type in coder inference for MapElements and FlatMapElements
Date Fri, 05 Aug 2016 17:10:16 GMT
Repository: incubator-beam
Updated Branches:
  refs/heads/master 8daf518bc -> 2b5c6bcb2


Use input type in coder inference for MapElements and FlatMapElements

Previously, the input TypeDescriptor was unknown, so we would fail
to infer a coder for things like MapElements.via(SimpleFunction<T, T>)
even if the input PCollection provided a coder for T.

Now the input type is plumbed through to the generated DoFn (via getInputTypeDescriptor), so the output coder can be inferred.


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/4ac5cafe
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/4ac5cafe
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/4ac5cafe

Branch: refs/heads/master
Commit: 4ac5cafe90a371cf616f97cb202d5016b68616d1
Parents: 8daf518
Author: Kenneth Knowles <klk@google.com>
Authored: Fri Jul 29 10:35:01 2016 -0700
Committer: Kenneth Knowles <klk@google.com>
Committed: Thu Aug 4 20:18:59 2016 -0700

----------------------------------------------------------------------
 .../beam/sdk/transforms/FlatMapElements.java    | 126 +++++++++++++------
 .../apache/beam/sdk/transforms/MapElements.java |  60 +++++----
 .../beam/sdk/transforms/SimpleFunction.java     |  34 +++++
 .../sdk/transforms/FlatMapElementsTest.java     |  48 +++++++
 .../beam/sdk/transforms/MapElementsTest.java    |  84 +++++++++++++
 5 files changed, 288 insertions(+), 64 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ac5cafe/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/FlatMapElements.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/FlatMapElements.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/FlatMapElements.java
index 694592e..04d993c 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/FlatMapElements.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/FlatMapElements.java
@@ -17,8 +17,10 @@
  */
 package org.apache.beam.sdk.transforms;
 
+import org.apache.beam.sdk.transforms.display.DisplayData;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.TypeDescriptor;
+import org.apache.beam.sdk.values.TypeDescriptors;
 
 import java.lang.reflect.ParameterizedType;
 
@@ -45,8 +47,16 @@ extends PTransform<PCollection<InputT>, PCollection<OutputT>> {
    * descriptor need not be provided.
    */
   public static <InputT, OutputT> MissingOutputTypeDescriptor<InputT, OutputT>
-  via(SerializableFunction<InputT, ? extends Iterable<OutputT>> fn) {
-    return new MissingOutputTypeDescriptor<>(fn);
+  via(SerializableFunction<? super InputT, ? extends Iterable<OutputT>> fn) {
+
+    // TypeDescriptor interacts poorly with the wildcards needed to correctly express
+    // covariance and contravariance in Java, so instead we cast it to an invariant
+    // function here.
+    @SuppressWarnings("unchecked") // safe covariant cast
+    SerializableFunction<InputT, Iterable<OutputT>> simplerFn =
+        (SerializableFunction<InputT, Iterable<OutputT>>) fn;
+
+    return new MissingOutputTypeDescriptor<>(simplerFn);
   }
 
   /**
@@ -72,16 +82,15 @@ extends PTransform<PCollection<InputT>, PCollection<OutputT>> {
    * <p>To use a Java 8 lambda, see {@link #via(SerializableFunction)}.
    */
   public static <InputT, OutputT> FlatMapElements<InputT, OutputT>
-  via(SimpleFunction<InputT, ? extends Iterable<OutputT>> fn) {
-
-    @SuppressWarnings({"rawtypes", "unchecked"}) // safe by static typing
-    TypeDescriptor<Iterable<?>> iterableType = (TypeDescriptor) fn.getOutputTypeDescriptor();
-
-    @SuppressWarnings("unchecked") // safe by correctness of getIterableElementType
-    TypeDescriptor<OutputT> outputType =
-        (TypeDescriptor<OutputT>) getIterableElementType(iterableType);
-
-    return new FlatMapElements<>(fn, outputType);
+  via(SimpleFunction<? super InputT, ? extends Iterable<OutputT>> fn) {
+    // TypeDescriptor interacts poorly with the wildcards needed to correctly express
+    // covariance and contravariance in Java, so instead we cast it to an invariant
+    // function here.
+    @SuppressWarnings("unchecked") // safe covariant cast
+    SimpleFunction<InputT, Iterable<OutputT>> simplerFn =
+        (SimpleFunction<InputT, Iterable<OutputT>>) fn;
+
+    return new FlatMapElements<>(simplerFn, fn.getClass());
   }
 
   /**
@@ -91,18 +100,80 @@ extends PTransform<PCollection<InputT>, PCollection<OutputT>> {
    */
   public static final class MissingOutputTypeDescriptor<InputT, OutputT> {
 
-    private final SerializableFunction<InputT, ? extends Iterable<OutputT>> fn;
+    private final SerializableFunction<InputT, Iterable<OutputT>> fn;
 
     private MissingOutputTypeDescriptor(
-        SerializableFunction<InputT, ? extends Iterable<OutputT>> fn) {
+        SerializableFunction<InputT, Iterable<OutputT>> fn) {
       this.fn = fn;
     }
 
     public FlatMapElements<InputT, OutputT> withOutputType(TypeDescriptor<OutputT>
outputType) {
-      return new FlatMapElements<>(fn, outputType);
+      TypeDescriptor<Iterable<OutputT>> iterableOutputType = TypeDescriptors.iterables(outputType);
+
+      return new FlatMapElements<>(
+          SimpleFunction.fromSerializableFunctionWithOutputType(fn,
+              iterableOutputType),
+              fn.getClass());
     }
   }
 
+  //////////////////////////////////////////////////////////////////////////////////////////////////
+
+  private final SimpleFunction<InputT, ? extends Iterable<OutputT>> fn;
+  private final DisplayData.Item<?> fnClassDisplayData;
+
+  private FlatMapElements(
+      SimpleFunction<InputT, ? extends Iterable<OutputT>> fn,
+      Class<?> fnClass) {
+    this.fn = fn;
+    this.fnClassDisplayData = DisplayData.item("flatMapFn", fnClass).withLabel("FlatMap Function");
+  }
+
+  @Override
+  public PCollection<OutputT> apply(PCollection<InputT> input) {
+    return input.apply(
+        "FlatMap",
+        ParDo.of(
+            new DoFn<InputT, OutputT>() {
+              private static final long serialVersionUID = 0L;
+
+              @ProcessElement
+              public void processElement(ProcessContext c) {
+                for (OutputT element : fn.apply(c.element())) {
+                  c.output(element);
+                }
+              }
+
+              @Override
+              public TypeDescriptor<InputT> getInputTypeDescriptor() {
+                return fn.getInputTypeDescriptor();
+              }
+
+              @Override
+              public TypeDescriptor<OutputT> getOutputTypeDescriptor() {
+                @SuppressWarnings({"rawtypes", "unchecked"}) // safe by static typing
+                TypeDescriptor<Iterable<?>> iterableType =
+                    (TypeDescriptor) fn.getOutputTypeDescriptor();
+
+                @SuppressWarnings("unchecked") // safe by correctness of getIterableElementType
+                TypeDescriptor<OutputT> outputType =
+                    (TypeDescriptor<OutputT>) getIterableElementType(iterableType);
+
+                return outputType;
+              }
+            }));
+  }
+
+  @Override
+  public void populateDisplayData(DisplayData.Builder builder) {
+    super.populateDisplayData(builder);
+    builder.add(fnClassDisplayData);
+  }
+
+  /**
+   * Does a best-effort job of getting the best {@link TypeDescriptor} for the type of the
+   * elements contained in the iterable described by the given {@link TypeDescriptor}.
+   */
   private static TypeDescriptor<?> getIterableElementType(
       TypeDescriptor<Iterable<?>> iterableTypeDescriptor) {
 
@@ -118,29 +189,4 @@ extends PTransform<PCollection<InputT>, PCollection<OutputT>> {
         (ParameterizedType) iterableTypeDescriptor.getSupertype(Iterable.class).getType();
     return TypeDescriptor.of(iterableType.getActualTypeArguments()[0]);
   }
-
-  //////////////////////////////////////////////////////////////////////////////////////////////////
-
-  private final SerializableFunction<InputT, ? extends Iterable<OutputT>> fn;
-  private final transient TypeDescriptor<OutputT> outputType;
-
-  private FlatMapElements(
-      SerializableFunction<InputT, ? extends Iterable<OutputT>> fn,
-      TypeDescriptor<OutputT> outputType) {
-    this.fn = fn;
-    this.outputType = outputType;
-  }
-
-  @Override
-  public PCollection<OutputT> apply(PCollection<InputT> input) {
-    return input.apply("Map", ParDo.of(new DoFn<InputT, OutputT>() {
-      private static final long serialVersionUID = 0L;
-      @ProcessElement
-      public void processElement(ProcessContext c) {
-        for (OutputT element : fn.apply(c.element())) {
-          c.output(element);
-        }
-      }
-    })).setTypeDescriptorInternal(outputType);
-  }
 }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ac5cafe/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java
index b7b9a5f..429d3fc 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java
@@ -67,9 +67,9 @@ extends PTransform<PCollection<InputT>, PCollection<OutputT>> {
    *     }));
    * }</pre>
    */
-  public static <InputT, OutputT> MapElements<InputT, OutputT>
-  via(final SimpleFunction<InputT, OutputT> fn) {
-    return new MapElements<>(fn, fn.getOutputTypeDescriptor());
+  public static <InputT, OutputT> MapElements<InputT, OutputT> via(
+      final SimpleFunction<InputT, OutputT> fn) {
+    return new MapElements<>(fn, fn.getClass());
   }
 
   /**
@@ -85,42 +85,54 @@ extends PTransform<PCollection<InputT>, PCollection<OutputT>> {
       this.fn = fn;
     }
 
-    public MapElements<InputT, OutputT> withOutputType(TypeDescriptor<OutputT>
outputType) {
-      return new MapElements<>(fn, outputType);
+    public MapElements<InputT, OutputT> withOutputType(final TypeDescriptor<OutputT>
outputType) {
+      return new MapElements<>(
+          SimpleFunction.fromSerializableFunctionWithOutputType(fn, outputType), fn.getClass());
     }
+
   }
 
   ///////////////////////////////////////////////////////////////////
 
-  private final SerializableFunction<InputT, OutputT> fn;
-  private final transient TypeDescriptor<OutputT> outputType;
+  private final SimpleFunction<InputT, OutputT> fn;
+  private final DisplayData.Item<?> fnClassDisplayData;
 
-  private MapElements(
-      SerializableFunction<InputT, OutputT> fn,
-      TypeDescriptor<OutputT> outputType) {
+  private MapElements(SimpleFunction<InputT, OutputT> fn, Class<?> fnClass) {
     this.fn = fn;
-    this.outputType = outputType;
+    this.fnClassDisplayData = DisplayData.item("mapFn", fnClass).withLabel("Map Function");
   }
 
   @Override
   public PCollection<OutputT> apply(PCollection<InputT> input) {
-    return input.apply("Map", ParDo.of(new DoFn<InputT, OutputT>() {
-      @ProcessElement
-      public void processElement(ProcessContext c) {
-        c.output(fn.apply(c.element()));
-      }
-
-      @Override
-      public void populateDisplayData(DisplayData.Builder builder) {
-        MapElements.this.populateDisplayData(builder);
-      }
-    })).setTypeDescriptorInternal(outputType);
+    return input.apply(
+        "Map",
+        ParDo.of(
+            new DoFn<InputT, OutputT>() {
+              @ProcessElement
+              public void processElement(ProcessContext c) {
+                c.output(fn.apply(c.element()));
+              }
+
+              @Override
+              public void populateDisplayData(DisplayData.Builder builder) {
+                MapElements.this.populateDisplayData(builder);
+              }
+
+              @Override
+              public TypeDescriptor<InputT> getInputTypeDescriptor() {
+                return fn.getInputTypeDescriptor();
+              }
+
+              @Override
+              public TypeDescriptor<OutputT> getOutputTypeDescriptor() {
+                return fn.getOutputTypeDescriptor();
+              }
+            }));
   }
 
   @Override
   public void populateDisplayData(DisplayData.Builder builder) {
     super.populateDisplayData(builder);
-    builder.add(DisplayData.item("mapFn", fn.getClass())
-      .withLabel("Map Function"));
+    builder.add(fnClassDisplayData);
   }
 }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ac5cafe/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/SimpleFunction.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/SimpleFunction.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/SimpleFunction.java
index 8894352..6c540cc 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/SimpleFunction.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/SimpleFunction.java
@@ -27,6 +27,12 @@ import org.apache.beam.sdk.values.TypeDescriptor;
 public abstract class SimpleFunction<InputT, OutputT>
     implements SerializableFunction<InputT, OutputT> {
 
+  public static <InputT, OutputT>
+      SimpleFunction<InputT, OutputT> fromSerializableFunctionWithOutputType(
+          SerializableFunction<InputT, OutputT> fn, TypeDescriptor<OutputT> outputType)
{
+    return new SimpleFunctionWithOutputType<>(fn, outputType);
+  }
+
   /**
    * Returns a {@link TypeDescriptor} capturing what is known statically
    * about the input type of this {@code OldDoFn} instance's most-derived
@@ -52,4 +58,32 @@ public abstract class SimpleFunction<InputT, OutputT>
   public TypeDescriptor<OutputT> getOutputTypeDescriptor() {
     return new TypeDescriptor<OutputT>(this) {};
   }
+
+  /**
+   * A {@link SimpleFunction} built from a {@link SerializableFunction}, having
+   * a known output type that is explicitly set.
+   */
+  private static class SimpleFunctionWithOutputType<InputT, OutputT>
+      extends SimpleFunction<InputT, OutputT> {
+
+    private final SerializableFunction<InputT, OutputT> fn;
+    private final TypeDescriptor<OutputT> outputType;
+
+    public SimpleFunctionWithOutputType(
+        SerializableFunction<InputT, OutputT> fn,
+        TypeDescriptor<OutputT> outputType) {
+      this.fn = fn;
+      this.outputType = outputType;
+    }
+
+    @Override
+    public OutputT apply(InputT input) {
+      return fn.apply(input);
+    }
+
+    @Override
+    public TypeDescriptor<OutputT> getOutputTypeDescriptor() {
+      return outputType;
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ac5cafe/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlatMapElementsTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlatMapElementsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlatMapElementsTest.java
index 057fd19..781e143 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlatMapElementsTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlatMapElementsTest.java
@@ -17,6 +17,8 @@
  */
 package org.apache.beam.sdk.transforms;
 
+import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem;
+
 import static org.hamcrest.Matchers.equalTo;
 import static org.junit.Assert.assertThat;
 
@@ -24,6 +26,7 @@ import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.testing.NeedsRunner;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.display.DisplayData;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.TypeDescriptor;
@@ -102,6 +105,51 @@ public class FlatMapElementsTest implements Serializable {
     pipeline.run();
   }
 
+  /**
+   * A {@link SimpleFunction} to test that the coder registry can propagate coders
+   * that are bound to type variables.
+   */
+  private static class PolymorphicSimpleFunction<T> extends SimpleFunction<T, Iterable<T>>
{
+    @Override
+    public Iterable<T> apply(T input) {
+      return Collections.<T>emptyList();
+    }
+  }
+
+  /**
+   * Basic test of {@link MapElements} coder propagation with a parametric {@link SimpleFunction}.
+   */
+  @Test
+  public void testPolymorphicSimpleFunction() throws Exception {
+    Pipeline pipeline = TestPipeline.create();
+    PCollection<Integer> output = pipeline
+        .apply(Create.of(1, 2, 3))
+
+        // This is the function that needs to propagate the input T to output T
+        .apply("Polymorphic Identity", MapElements.via(new PolymorphicSimpleFunction<Integer>()))
+
+        // This is a consumer to ensure that all coder inference logic is executed.
+        .apply("Test Consumer", MapElements.via(new SimpleFunction<Iterable<Integer>,
Integer>() {
+          @Override
+          public Integer apply(Iterable<Integer> input) {
+            return 42;
+          }
+        }));
+  }
+
+  @Test
+  public void testSimpleFunctionClassDisplayData() {
+    SimpleFunction<Integer, List<Integer>> simpleFn = new SimpleFunction<Integer,
List<Integer>>() {
+      @Override
+      public List<Integer> apply(Integer input) {
+        return Collections.emptyList();
+      }
+    };
+
+    FlatMapElements<?, ?> simpleMap = FlatMapElements.via(simpleFn);
+    assertThat(DisplayData.from(simpleMap), hasDisplayItem("flatMapFn", simpleFn.getClass()));
+  }
+
   @Test
   @Category(NeedsRunner.class)
   public void testVoidValues() throws Exception {

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ac5cafe/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MapElementsTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MapElementsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MapElementsTest.java
index b4751d2..dbf8844 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MapElementsTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MapElementsTest.java
@@ -54,6 +54,29 @@ public class MapElementsTest implements Serializable {
   public transient ExpectedException thrown = ExpectedException.none();
 
   /**
+   * A {@link SimpleFunction} to test that the coder registry can propagate coders
+   * that are bound to type variables.
+   */
+  private static class PolymorphicSimpleFunction<T> extends SimpleFunction<T, T>
{
+    @Override
+    public T apply(T input) {
+      return input;
+    }
+  }
+
+  /**
+   * A {@link SimpleFunction} to test that the coder registry can propagate coders
+   * that are bound to type variables, when the variable appears nested in the
+   * output.
+   */
+  private static class NestedPolymorphicSimpleFunction<T> extends SimpleFunction<T,
KV<T, String>> {
+    @Override
+    public KV<T, String> apply(T input) {
+      return KV.of(input, "hello");
+    }
+  }
+
+  /**
    * Basic test of {@link MapElements} with a {@link SimpleFunction}.
    */
   @Test
@@ -74,6 +97,55 @@ public class MapElementsTest implements Serializable {
   }
 
   /**
+   * Basic test of {@link MapElements} coder propagation with a parametric {@link SimpleFunction}.
+   */
+  @Test
+  public void testPolymorphicSimpleFunction() throws Exception {
+    Pipeline pipeline = TestPipeline.create();
+    PCollection<Integer> output = pipeline
+        .apply(Create.of(1, 2, 3))
+
+        // This is the function that needs to propagate the input T to output T
+        .apply("Polymorphic Identity", MapElements.via(new PolymorphicSimpleFunction<Integer>()))
+
+        // This is a consumer to ensure that all coder inference logic is executed.
+        .apply("Test Consumer", MapElements.via(new SimpleFunction<Integer, Integer>()
{
+          @Override
+          public Integer apply(Integer input) {
+            return input;
+          }
+        }));
+  }
+
+  /**
+   * Test of {@link MapElements} coder propagation with a parametric {@link SimpleFunction}
+   * where the type variable occurs nested within other concrete type constructors.
+   */
+  @Test
+  public void testNestedPolymorphicSimpleFunction() throws Exception {
+    Pipeline pipeline = TestPipeline.create();
+    PCollection<Integer> output =
+        pipeline
+            .apply(Create.of(1, 2, 3))
+
+            // This is the function that needs to propagate the input T to output T
+            .apply(
+                "Polymorphic Identity",
+                MapElements.via(new NestedPolymorphicSimpleFunction<Integer>()))
+
+            // This is a consumer to ensure that all coder inference logic is executed.
+            .apply(
+                "Test Consumer",
+                MapElements.via(
+                    new SimpleFunction<KV<Integer, String>, Integer>() {
+                      @Override
+                      public Integer apply(KV<Integer, String> input) {
+                        return 42;
+                      }
+                    }));
+  }
+
+  /**
    * Basic test of {@link MapElements} with a {@link SerializableFunction}. This style is
    * generally discouraged in Java 7, in favor of {@link SimpleFunction}.
    */
@@ -148,6 +220,18 @@ public class MapElementsTest implements Serializable {
   }
 
   @Test
+  public void testSimpleFunctionClassDisplayData() {
+    SimpleFunction<?, ?> simpleFn = new SimpleFunction<Integer, Integer>() {
+      @Override
+      public Integer apply(Integer input) {
+        return input;
+      }
+    };
+
+    MapElements<?, ?> simpleMap = MapElements.via(simpleFn);
+    assertThat(DisplayData.from(simpleMap), hasDisplayItem("mapFn", simpleFn.getClass()));
+  }
+  @Test
   public void testSimpleFunctionDisplayData() {
     SimpleFunction<?, ?> simpleFn = new SimpleFunction<Integer, Integer>() {
       @Override


Mime
View raw message