beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From aviem...@apache.org
Subject [1/2] beam git commit: [BEAM-1397] Introduce IO metrics
Date Tue, 28 Mar 2017 03:52:35 GMT
Repository: beam
Updated Branches:
  refs/heads/master 85b820c37 -> 48fee91f7


[BEAM-1397] Introduce IO metrics


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/65b5f001
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/65b5f001
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/65b5f001

Branch: refs/heads/master
Commit: 65b5f001a4e1790206efe3ff2d418018680ea621
Parents: 85b820c
Author: Aviem Zur <aviemzur@gmail.com>
Authored: Tue Mar 28 06:49:59 2017 +0300
Committer: Aviem Zur <aviemzur@gmail.com>
Committed: Tue Mar 28 06:49:59 2017 +0300

----------------------------------------------------------------------
 .../beam/runners/spark/TestSparkRunner.java     | 14 +++-
 .../apache/beam/runners/spark/io/SourceRDD.java | 51 +++++++++-----
 .../runners/spark/io/SparkUnboundedSource.java  | 48 +++++++++----
 .../spark/metrics/SparkMetricsContainer.java    | 11 ++-
 .../spark/stateful/StateSpecFunctions.java      | 35 +++++++---
 .../spark/translation/TransformTranslator.java  |  3 +-
 .../streaming/StreamingTransformTranslator.java |  4 +-
 .../streaming/StreamingSourceMetricsTest.java   | 71 ++++++++++++++++++++
 .../org/apache/beam/sdk/io/CountingSource.java  |  8 +++
 .../apache/beam/sdk/metrics/MetricsTest.java    | 45 +++++++++++++
 10 files changed, 244 insertions(+), 46 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/65b5f001/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
index e40534f..be9ff2e 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
@@ -135,7 +135,12 @@ public final class TestSparkRunner extends PipelineRunner<SparkPipelineResult> {
             isOneOf(PipelineResult.State.STOPPED, PipelineResult.State.DONE));
 
         // validate assertion succeeded (at least once).
-        int successAssertions = result.getAggregatorValue(PAssert.SUCCESS_COUNTER, Integer.class);
+        int successAssertions = 0;
+        try {
+          successAssertions = result.getAggregatorValue(PAssert.SUCCESS_COUNTER, Integer.class);
+        } catch (NullPointerException e) {
+          // No assertions registered will cause an NPE here.
+        }
         Integer expectedAssertions = testSparkPipelineOptions.getExpectedAssertions() != null
             ? testSparkPipelineOptions.getExpectedAssertions() : expectedNumberOfAssertions;
         assertThat(
@@ -145,7 +150,12 @@ public final class TestSparkRunner extends PipelineRunner<SparkPipelineResult> {
             successAssertions,
             is(expectedAssertions));
         // validate assertion didn't fail.
-        int failedAssertions = result.getAggregatorValue(PAssert.FAILURE_COUNTER, Integer.class);
+        int failedAssertions = 0;
+        try {
+          failedAssertions = result.getAggregatorValue(PAssert.FAILURE_COUNTER, Integer.class);
+        } catch (NullPointerException e) {
+          // No assertions registered will cause an NPE here.
+        }
         assertThat(
             String.format("Found %d failed assertions.", failedAssertions),
             failedAssertions,

http://git-wip-us.apache.org/repos/asf/beam/blob/65b5f001/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceRDD.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceRDD.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceRDD.java
index 1a3537f..2f9a827 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceRDD.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceRDD.java
@@ -20,15 +20,21 @@ package org.apache.beam.runners.spark.io;
 
 import static com.google.common.base.Preconditions.checkArgument;
 
+import java.io.Closeable;
 import java.io.IOException;
 import java.util.Collections;
 import java.util.Iterator;
 import java.util.List;
+import org.apache.beam.runners.spark.metrics.MetricsAccumulator;
+import org.apache.beam.runners.spark.metrics.SparkMetricsContainer;
 import org.apache.beam.runners.spark.translation.SparkRuntimeContext;
 import org.apache.beam.sdk.io.BoundedSource;
 import org.apache.beam.sdk.io.Source;
 import org.apache.beam.sdk.io.UnboundedSource;
+import org.apache.beam.sdk.metrics.MetricsContainer;
+import org.apache.beam.sdk.metrics.MetricsEnvironment;
 import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.spark.Accumulator;
 import org.apache.spark.Dependency;
 import org.apache.spark.HashPartitioner;
 import org.apache.spark.InterruptibleIterator;
@@ -42,7 +48,6 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import scala.Option;
 
-
 /**
  * Classes implementing Beam {@link Source} {@link RDD}s.
  */
@@ -59,15 +64,19 @@ public class SourceRDD {
     private final BoundedSource<T> source;
     private final SparkRuntimeContext runtimeContext;
     private final int numPartitions;
+    private final String stepName;
+    private final Accumulator<SparkMetricsContainer> metricsAccum;
 
     // to satisfy Scala API.
     private static final scala.collection.immutable.Seq<Dependency<?>> NIL =
         scala.collection.JavaConversions
           .asScalaBuffer(Collections.<Dependency<?>>emptyList()).toList();
 
-    public Bounded(SparkContext sc,
-                   BoundedSource<T> source,
-                   SparkRuntimeContext runtimeContext) {
+    public Bounded(
+        SparkContext sc,
+        BoundedSource<T> source,
+        SparkRuntimeContext runtimeContext,
+        String stepName) {
       super(sc, NIL, JavaSparkContext$.MODULE$.<WindowedValue<T>>fakeClassTag());
       this.source = source;
       this.runtimeContext = runtimeContext;
@@ -79,6 +88,8 @@ public class SourceRDD {
       // ** the configuration "spark.default.parallelism" takes precedence over all of the above **
       this.numPartitions = sc.defaultParallelism();
       checkArgument(this.numPartitions > 0, "Number of partitions must be greater than zero.");
+      this.stepName = stepName;
+      this.metricsAccum = MetricsAccumulator.getInstance();
     }
 
     private static final long DEFAULT_BUNDLE_SIZE = 64 * 1024 * 1024;
@@ -110,6 +121,8 @@ public class SourceRDD {
     @Override
     public scala.collection.Iterator<WindowedValue<T>> compute(final Partition split,
                                                                TaskContext context) {
+      final MetricsContainer metricsContainer = metricsAccum.localValue().getContainer(stepName);
+
       final Iterator<WindowedValue<T>> iter = new Iterator<WindowedValue<T>>() {
         @SuppressWarnings("unchecked")
         SourcePartition<T> partition = (SourcePartition<T>) split;
@@ -121,21 +134,27 @@ public class SourceRDD {
 
         @Override
         public boolean hasNext() {
-          try {
-            if (!started) {
-              started = true;
-              finished = !reader.start();
-            } else {
-              finished = !reader.advance();
-            }
-            if (finished) {
-              // safely close the reader if there are no more elements left to read.
+          // Add metrics container to the scope of org.apache.beam.sdk.io.Source.Reader methods
+          // since they may report metrics.
+          try (Closeable ignored = MetricsEnvironment.scopedMetricsContainer(metricsContainer)) {
+            try {
+              if (!started) {
+                started = true;
+                finished = !reader.start();
+              } else {
+                finished = !reader.advance();
+              }
+              if (finished) {
+                // safely close the reader if there are no more elements left to read.
+                closeIfNotClosed();
+              }
+              return !finished;
+            } catch (IOException e) {
               closeIfNotClosed();
+              throw new RuntimeException("Failed to read from reader.", e);
             }
-            return !finished;
           } catch (IOException e) {
-            closeIfNotClosed();
-            throw new RuntimeException("Failed to read from reader.", e);
+            throw new RuntimeException(e);
           }
         }
 

http://git-wip-us.apache.org/repos/asf/beam/blob/65b5f001/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java
index 6c047ac..a538907 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java
@@ -22,6 +22,8 @@ import java.io.Serializable;
 import java.util.Collections;
 import org.apache.beam.runners.spark.SparkPipelineOptions;
 import org.apache.beam.runners.spark.coders.CoderHelpers;
+import org.apache.beam.runners.spark.metrics.MetricsAccumulator;
+import org.apache.beam.runners.spark.metrics.SparkMetricsContainer;
 import org.apache.beam.runners.spark.stateful.StateSpecFunctions;
 import org.apache.beam.runners.spark.translation.SparkRuntimeContext;
 import org.apache.beam.runners.spark.translation.streaming.UnboundedDataset;
@@ -33,6 +35,7 @@ import org.apache.beam.sdk.io.UnboundedSource.CheckpointMark;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
 import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.spark.Accumulator;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext$;
 import org.apache.spark.api.java.function.FlatMapFunction;
@@ -62,7 +65,7 @@ import scala.runtime.BoxedUnit;
  * <li>Create a single-element (per-partition) stream, that contains the (partitioned)
  * {@link Source} and an optional {@link CheckpointMark} to start from.</li>
  * <li>Read from within a stateful operation {@link JavaPairInputDStream#mapWithState(StateSpec)}
- * using the {@link StateSpecFunctions#mapSourceFunction(SparkRuntimeContext)} mapping function,
+ * using the {@link StateSpecFunctions#mapSourceFunction} mapping function,
  * which manages the state of the CheckpointMark per partition.</li>
  * <li>Since the stateful operation is a map operation, the read iterator needs to be flattened,
  * while reporting the properties of the read (such as number of records) to the tracker.</li>
@@ -73,7 +76,8 @@ public class SparkUnboundedSource {
   public static <T, CheckpointMarkT extends CheckpointMark> UnboundedDataset<T> read(
       JavaStreamingContext jssc,
       SparkRuntimeContext rc,
-      UnboundedSource<T, CheckpointMarkT> source) {
+      UnboundedSource<T, CheckpointMarkT> source,
+      String stepName) {
 
     SparkPipelineOptions options = rc.getPipelineOptions().as(SparkPipelineOptions.class);
     Long maxRecordsPerBatch = options.getMaxRecordsPerBatch();
@@ -90,7 +94,7 @@ public class SparkUnboundedSource {
         Tuple2<Iterable<byte[]>, Metadata>> mapWithStateDStream =
         inputDStream.mapWithState(
             StateSpec
-                .function(StateSpecFunctions.<T, CheckpointMarkT>mapSourceFunction(rc))
+                .function(StateSpecFunctions.<T, CheckpointMarkT>mapSourceFunction(rc, stepName))
                 .numPartitions(sourceDStream.getNumPartitions()));
 
     // set checkpoint duration for read stream, if set.
@@ -106,8 +110,8 @@ public class SparkUnboundedSource {
           }
         });
 
-    // register the ReportingDStream op.
-    new ReportingDStream(metadataDStream.dstream(), id, getSourceName(source, id)).register();
+    // register ReadReportDStream to report information related to this read.
+    new ReadReportDStream(metadataDStream.dstream(), id, getSourceName(source, id)).register();
 
     // output the actual (deserialized) stream.
     WindowedValue.FullWindowedValueCoder<T> coder =
@@ -144,16 +148,20 @@ public class SparkUnboundedSource {
   }
 
   /**
-   * A DStream function that reports the properties of the read to the
+   * A DStream function for reporting information related to the read process.
+   *
+   * <p>Reports properties of the read to
    * {@link org.apache.spark.streaming.scheduler.InputInfoTracker}
-   * for RateControl purposes and visibility.
+   * for RateControl purposes and visibility.</p>
+   * <p>Updates {@link GlobalWatermarkHolder}.</p>
+   * <p>Updates {@link MetricsAccumulator} with metrics reported in the read.</p>
    */
-  private static class ReportingDStream extends DStream<BoxedUnit> {
+  private static class ReadReportDStream extends DStream<BoxedUnit> {
     private final DStream<Metadata> parent;
     private final int inputDStreamId;
     private final String sourceName;
 
-    ReportingDStream(
+    ReadReportDStream(
         DStream<Metadata> parent,
         int inputDStreamId,
         String sourceName) {
@@ -178,6 +186,7 @@ public class SparkUnboundedSource {
     public scala.Option<RDD<BoxedUnit>> compute(Time validTime) {
       // compute parent.
       scala.Option<RDD<Metadata>> parentRDDOpt = parent.getOrCompute(validTime);
+      final Accumulator<SparkMetricsContainer> metricsAccum = MetricsAccumulator.getInstance();
       long count = 0;
       SparkWatermarks sparkWatermark = null;
       Instant globalLowWatermarkForBatch = BoundedWindow.TIMESTAMP_MIN_VALUE;
@@ -195,6 +204,8 @@ public class SparkUnboundedSource {
           globalHighWatermarkForBatch =
               globalHighWatermarkForBatch.isBefore(partitionHighWatermark)
                   ? partitionHighWatermark : globalHighWatermarkForBatch;
+          // Update metrics reported in the read
+          metricsAccum.value().update(metadata.getMetricsContainer());
         }
 
         sparkWatermark =
@@ -233,23 +244,34 @@ public class SparkUnboundedSource {
     private final long numRecords;
     private final Instant lowWatermark;
     private final Instant highWatermark;
+    private final SparkMetricsContainer metricsContainer;
 
-    public Metadata(long numRecords, Instant lowWatermark, Instant highWatermark) {
+    public Metadata(
+        long numRecords,
+        Instant lowWatermark,
+        Instant highWatermark,
+        SparkMetricsContainer metricsContainer) {
       this.numRecords = numRecords;
+      this.metricsContainer = metricsContainer;
       this.lowWatermark = lowWatermark;
       this.highWatermark = highWatermark;
+      metricsContainer.materialize();
     }
 
-    public long getNumRecords() {
+    long getNumRecords() {
       return numRecords;
     }
 
-    public Instant getLowWatermark() {
+    Instant getLowWatermark() {
       return lowWatermark;
     }
 
-    public Instant getHighWatermark() {
+    Instant getHighWatermark() {
       return highWatermark;
     }
+
+    SparkMetricsContainer getMetricsContainer() {
+      return metricsContainer;
+    }
   }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/65b5f001/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/SparkMetricsContainer.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/SparkMetricsContainer.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/SparkMetricsContainer.java
index b6aa178..9e94c14 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/SparkMetricsContainer.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/SparkMetricsContainer.java
@@ -82,7 +82,8 @@ public class SparkMetricsContainer implements Serializable {
     return getInstance().gauges.values();
   }
 
-  SparkMetricsContainer update(SparkMetricsContainer other) {
+  public SparkMetricsContainer update(SparkMetricsContainer other) {
+    other.materialize();
     this.updateCounters(other.counters.values());
     this.updateDistributions(other.distributions.values());
     this.updateGauges(other.gauges.values());
@@ -102,7 +103,12 @@ public class SparkMetricsContainer implements Serializable {
     out.defaultWriteObject();
   }
 
-  private void materialize() {
+  /**
+   * Materialize metrics. Must be called to enable this instance's data to be serialized correctly.
+   * This method is idempotent.
+   */
+  public void materialize() {
+    // Nullifying metricsContainers makes this method idempotent.
     if (metricsContainers != null) {
       for (MetricsContainer container : metricsContainers.asMap().values()) {
         MetricUpdates cumulative = container.getCumulative();
@@ -110,6 +116,7 @@ public class SparkMetricsContainer implements Serializable {
         this.updateDistributions(cumulative.distributionUpdates());
         this.updateGauges(cumulative.gaugeUpdates());
       }
+      metricsContainers = null;
     }
   }
 

http://git-wip-us.apache.org/repos/asf/beam/blob/65b5f001/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/StateSpecFunctions.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/StateSpecFunctions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/StateSpecFunctions.java
index ae5a746..ec4fce3 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/StateSpecFunctions.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/StateSpecFunctions.java
@@ -20,6 +20,7 @@ package org.apache.beam.runners.spark.stateful;
 
 import com.google.common.base.Stopwatch;
 import com.google.common.collect.Iterators;
+import java.io.Closeable;
 import java.io.IOException;
 import java.io.Serializable;
 import java.util.ArrayList;
@@ -30,11 +31,14 @@ import org.apache.beam.runners.spark.coders.CoderHelpers;
 import org.apache.beam.runners.spark.io.EmptyCheckpointMark;
 import org.apache.beam.runners.spark.io.MicrobatchSource;
 import org.apache.beam.runners.spark.io.SparkUnboundedSource.Metadata;
+import org.apache.beam.runners.spark.metrics.SparkMetricsContainer;
 import org.apache.beam.runners.spark.translation.SparkRuntimeContext;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.io.BoundedSource;
 import org.apache.beam.sdk.io.Source;
 import org.apache.beam.sdk.io.UnboundedSource;
+import org.apache.beam.sdk.metrics.MetricsContainer;
+import org.apache.beam.sdk.metrics.MetricsEnvironment;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
 import org.apache.beam.sdk.transforms.windowing.PaneInfo;
@@ -97,7 +101,7 @@ public class StateSpecFunctions {
   public static <T, CheckpointMarkT extends UnboundedSource.CheckpointMark>
   scala.Function3<Source<T>, scala.Option<CheckpointMarkT>, State<Tuple2<byte[], Instant>>,
       Tuple2<Iterable<byte[]>, Metadata>> mapSourceFunction(
-           final SparkRuntimeContext runtimeContext) {
+           final SparkRuntimeContext runtimeContext, final String stepName) {
 
     return new SerializableFunction3<Source<T>, Option<CheckpointMarkT>,
         State<Tuple2<byte[], Instant>>, Tuple2<Iterable<byte[]>, Metadata>>() {
@@ -108,9 +112,14 @@ public class StateSpecFunctions {
           scala.Option<CheckpointMarkT> startCheckpointMark,
           State<Tuple2<byte[], Instant>> state) {
 
-        // source as MicrobatchSource
-        MicrobatchSource<T, CheckpointMarkT> microbatchSource =
-            (MicrobatchSource<T, CheckpointMarkT>) source;
+        SparkMetricsContainer sparkMetricsContainer = new SparkMetricsContainer();
+        MetricsContainer metricsContainer = sparkMetricsContainer.getContainer(stepName);
+        // Add metrics container to the scope of org.apache.beam.sdk.io.Source.Reader methods
+        // since they may report metrics.
+        try (Closeable ignored = MetricsEnvironment.scopedMetricsContainer(metricsContainer)) {
+          // source as MicrobatchSource
+          MicrobatchSource<T, CheckpointMarkT> microbatchSource =
+              (MicrobatchSource<T, CheckpointMarkT>) source;
 
         // Initial high/low watermarks.
         Instant lowWatermark = BoundedWindow.TIMESTAMP_MIN_VALUE;
@@ -188,13 +197,17 @@ public class StateSpecFunctions {
           throw new RuntimeException("Failed to read from reader.", e);
         }
 
-        Iterable <byte[]> iterable = new Iterable<byte[]>() {
-          @Override
-          public Iterator<byte[]> iterator() {
-            return Iterators.unmodifiableIterator(readValues.iterator());
-          }
-        };
-        return new Tuple2<>(iterable, new Metadata(readValues.size(), lowWatermark, highWatermark));
+          Iterable <byte[]> iterable = new Iterable<byte[]>() {
+            @Override
+            public Iterator<byte[]> iterator() {
+              return Iterators.unmodifiableIterator(readValues.iterator());
+            }
+          };
+          return new Tuple2<>(iterable,
+              new Metadata(readValues.size(), lowWatermark, highWatermark, sparkMetricsContainer));
+        } catch (IOException e) {
+          throw new RuntimeException(e);
+        }
       }
     };
   }

http://git-wip-us.apache.org/repos/asf/beam/blob/65b5f001/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
index ffb207a..b57860a 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
@@ -543,11 +543,12 @@ public final class TransformTranslator {
     return new TransformEvaluator<Read.Bounded<T>>() {
       @Override
       public void evaluate(Read.Bounded<T> transform, EvaluationContext context) {
+        String stepName = context.getCurrentTransform().getFullName();
         final JavaSparkContext jsc = context.getSparkContext();
         final SparkRuntimeContext runtimeContext = context.getRuntimeContext();
         // create an RDD from a BoundedSource.
         JavaRDD<WindowedValue<T>> input = new SourceRDD.Bounded<>(
-            jsc.sc(), transform.getSource(), runtimeContext).toJavaRDD();
+            jsc.sc(), transform.getSource(), runtimeContext, stepName).toJavaRDD();
         // cache to avoid re-evaluation of the source by Spark's lazy DAG evaluation.
         context.putDataset(transform, new BoundedDataset<>(input.cache()));
       }

http://git-wip-us.apache.org/repos/asf/beam/blob/65b5f001/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
index 25fecf6..b88731c 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
@@ -121,12 +121,14 @@ public final class StreamingTransformTranslator {
     return new TransformEvaluator<Read.Unbounded<T>>() {
       @Override
       public void evaluate(Read.Unbounded<T> transform, EvaluationContext context) {
+        final String stepName = context.getCurrentTransform().getFullName();
         context.putDataset(
             transform,
             SparkUnboundedSource.read(
                 context.getStreamingContext(),
                 context.getRuntimeContext(),
-                transform.getSource()));
+                transform.getSource(),
+                stepName));
       }
 
       @Override

http://git-wip-us.apache.org/repos/asf/beam/blob/65b5f001/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/StreamingSourceMetricsTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/StreamingSourceMetricsTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/StreamingSourceMetricsTest.java
new file mode 100644
index 0000000..ea76d31
--- /dev/null
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/StreamingSourceMetricsTest.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.spark.translation.streaming;
+
+import static org.apache.beam.sdk.metrics.MetricMatchers.attemptedMetricsResult;
+import static org.hamcrest.Matchers.hasItem;
+import static org.junit.Assert.assertThat;
+
+import java.io.Serializable;
+import org.apache.beam.runners.spark.PipelineRule;
+import org.apache.beam.runners.spark.TestSparkPipelineOptions;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.PipelineResult;
+import org.apache.beam.sdk.io.CountingInput;
+import org.apache.beam.sdk.io.Source;
+import org.apache.beam.sdk.metrics.MetricNameFilter;
+import org.apache.beam.sdk.metrics.MetricQueryResults;
+import org.apache.beam.sdk.metrics.MetricsFilter;
+import org.junit.Rule;
+import org.junit.Test;
+
+
+/**
+ * Verify metrics support for {@link Source Sources} in streaming pipelines.
+ */
+public class StreamingSourceMetricsTest implements Serializable {
+
+  // Force streaming pipeline using pipeline rule.
+  @Rule
+  public final transient PipelineRule pipelineRule = PipelineRule.streaming();
+
+  @Test
+  public void testUnboundedSourceMetrics() {
+    TestSparkPipelineOptions options = pipelineRule.getOptions();
+
+    Pipeline pipeline = Pipeline.create(options);
+
+    final long numElements = 1000;
+
+    pipeline.apply(CountingInput.unbounded().withMaxNumRecords(numElements));
+
+    PipelineResult pipelineResult = pipeline.run();
+
+    MetricQueryResults metrics =
+        pipelineResult
+            .metrics()
+            .queryMetrics(
+                MetricsFilter.builder()
+                    .addNameFilter(MetricNameFilter.named("io", "elementsRead"))
+                    .build());
+
+    assertThat(metrics.counters(), hasItem(
+        attemptedMetricsResult("io", "elementsRead", "Read(UnboundedCountingSource)", 1000L)));
+  }
+}

http://git-wip-us.apache.org/repos/asf/beam/blob/65b5f001/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CountingSource.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CountingSource.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CountingSource.java
index 4b9ec66..4d1305c 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CountingSource.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CountingSource.java
@@ -30,6 +30,8 @@ import org.apache.beam.sdk.coders.DefaultCoder;
 import org.apache.beam.sdk.coders.VarLongCoder;
 import org.apache.beam.sdk.io.CountingInput.UnboundedCountingInput;
 import org.apache.beam.sdk.io.UnboundedSource.UnboundedReader;
+import org.apache.beam.sdk.metrics.Counter;
+import org.apache.beam.sdk.metrics.Metrics;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.values.PCollection;
@@ -207,6 +209,8 @@ public class CountingSource {
   private static class BoundedCountingReader extends OffsetBasedSource.OffsetBasedReader<Long> {
     private long current;
 
+    private final Counter elementsRead = Metrics.counter("io", "elementsRead");
+
     public BoundedCountingReader(OffsetBasedSource<Long> source) {
       super(source);
     }
@@ -239,6 +243,7 @@ public class CountingSource {
 
     @Override
     protected boolean advanceImpl() throws IOException {
+      elementsRead.inc();
       current++;
       return true;
     }
@@ -368,6 +373,8 @@ public class CountingSource {
     private Instant currentTimestamp;
     private Instant firstStarted;
 
+    private final Counter elementsRead = Metrics.counter("io", "elementsRead");
+
     public UnboundedCountingReader(UnboundedCountingSource source, CounterMark mark) {
       this.source = source;
       if (mark == null) {
@@ -398,6 +405,7 @@ public class CountingSource {
       if (expectedValue() < nextValue) {
         return false;
       }
+      elementsRead.inc();
       current = nextValue;
       currentTimestamp = source.timestampFn.apply(current);
       return true;

http://git-wip-us.apache.org/repos/asf/beam/blob/65b5f001/sdks/java/core/src/test/java/org/apache/beam/sdk/metrics/MetricsTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/metrics/MetricsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/metrics/MetricsTest.java
index 265b519..27e8411 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/metrics/MetricsTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/metrics/MetricsTest.java
@@ -29,6 +29,7 @@ import static org.junit.Assert.assertThat;
 
 import java.io.Serializable;
 import org.apache.beam.sdk.PipelineResult;
+import org.apache.beam.sdk.io.CountingInput;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.testing.UsesAttemptedMetrics;
 import org.apache.beam.sdk.testing.UsesCommittedMetrics;
@@ -36,6 +37,7 @@ import org.apache.beam.sdk.testing.ValidatesRunner;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.sdk.values.TupleTagList;
 import org.hamcrest.CoreMatchers;
@@ -227,4 +229,47 @@ public class MetricsTest implements Serializable {
     result.waitUntilFinish();
     return result;
   }
+
+  @Test
+  @Category({ValidatesRunner.class, UsesAttemptedMetrics.class})
+  public void testBoundedSourceMetrics() {
+    long numElements = 1000;
+
+    PCollection<Long> input = pipeline.apply(CountingInput.upTo(numElements));
+
+    PipelineResult pipelineResult = pipeline.run();
+
+    MetricQueryResults metrics =
+        pipelineResult
+            .metrics()
+            .queryMetrics(
+                MetricsFilter.builder()
+                    .addNameFilter(MetricNameFilter.named("io", "elementsRead"))
+                    .build());
+
+    assertThat(metrics.counters(), hasItem(
+        attemptedMetricsResult("io", "elementsRead", "Read(BoundedCountingSource)", 1000L)));
+  }
+
+  @Test
+  @Category({ValidatesRunner.class, UsesAttemptedMetrics.class})
+  public void testUnboundedSourceMetrics() {
+    long numElements = 1000;
+
+    PCollection<Long> input = pipeline
+        .apply((CountingInput.unbounded()).withMaxNumRecords(numElements));
+
+    PipelineResult pipelineResult = pipeline.run();
+
+    MetricQueryResults metrics =
+        pipelineResult
+            .metrics()
+            .queryMetrics(
+                MetricsFilter.builder()
+                    .addNameFilter(MetricNameFilter.named("io", "elementsRead"))
+                    .build());
+
+    assertThat(metrics.counters(), hasItem(
+        attemptedMetricsResult("io", "elementsRead", "Read(UnboundedCountingSource)", 1000L)));
+  }
 }


Mime
View raw message