beam-commits mailing list archives

From: amits...@apache.org
Subject: [1/7] beam git commit: [BEAM-774] Implement Metrics support for Spark runner
Date: Wed, 15 Feb 2017 09:29:15 GMT
Repository: beam
Updated Branches:
  refs/heads/master e720a7c43 -> 24ecf6bbf


[BEAM-774] Implement Metrics support for Spark runner


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/8e203ea2
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/8e203ea2
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/8e203ea2

Branch: refs/heads/master
Commit: 8e203ea2ef34f01eacaa99eb43294f143532ecc3
Parents: e720a7c
Author: Aviem Zur <aviemzur@gmail.com>
Authored: Wed Jan 11 14:42:53 2017 +0200
Committer: Sela <ansela@paypal.com>
Committed: Wed Feb 15 11:10:47 2017 +0200

----------------------------------------------------------------------
 .../beam/runners/spark/SparkPipelineResult.java |   5 +-
 .../apache/beam/runners/spark/SparkRunner.java  |   3 +
 .../runners/spark/metrics/MetricAggregator.java | 113 ++++++++
 .../spark/metrics/MetricsAccumulator.java       |  60 ++++
 .../spark/metrics/MetricsAccumulatorParam.java  |  42 +++
 .../spark/metrics/SparkMetricResults.java       | 188 ++++++++++++
 .../spark/metrics/SparkMetricsContainer.java    | 288 +++++++++++++++++++
 .../runners/spark/metrics/package-info.java     |  20 ++
 .../runners/spark/translation/DoFnFunction.java |  26 +-
 .../translation/DoFnRunnerWithMetrics.java      |  98 +++++++
 .../spark/translation/EvaluationContext.java    |   4 +
 .../spark/translation/SparkContextFactory.java  |   2 +
 .../spark/translation/TransformTranslator.java  |  13 +-
 .../streaming/StreamingTransformTranslator.java |  12 +-
 .../apache/beam/sdk/metrics/MetricMatchers.java |  96 +++++++
 .../apache/beam/sdk/metrics/MetricsTest.java    |  24 +-
 16 files changed, 974 insertions(+), 20 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/8e203ea2/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineResult.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineResult.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineResult.java
index b1027a6..d0d5569 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineResult.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineResult.java
@@ -24,6 +24,8 @@ import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
 import org.apache.beam.runners.spark.aggregators.SparkAggregators;
+import org.apache.beam.runners.spark.metrics.SparkMetricResults;
+import org.apache.beam.runners.spark.metrics.SparkMetricsContainer;
 import org.apache.beam.runners.spark.translation.SparkContextFactory;
 import org.apache.beam.sdk.AggregatorRetrievalException;
 import org.apache.beam.sdk.AggregatorValues;
@@ -122,7 +124,8 @@ public abstract class SparkPipelineResult implements PipelineResult {
 
   @Override
   public MetricResults metrics() {
-    throw new UnsupportedOperationException("The SparkRunner does not currently support metrics.");
+    return new SparkMetricResults(
+        SparkMetricsContainer.getAccumulator(SparkContextFactory.EMPTY_CONTEXT));
   }
 
   @Override

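With this change, metrics become queryable from a finished Spark pipeline through
the standard PipelineResult API. A minimal caller-side sketch, assuming a batch
pipeline and a counter registered under the hypothetical namespace "MyDoFn":

    import org.apache.beam.sdk.PipelineResult;
    import org.apache.beam.sdk.metrics.MetricNameFilter;
    import org.apache.beam.sdk.metrics.MetricQueryResults;
    import org.apache.beam.sdk.metrics.MetricResult;
    import org.apache.beam.sdk.metrics.MetricsFilter;

    PipelineResult result = pipeline.run();
    result.waitUntilFinish();

    // Query only the metrics declared in the "MyDoFn" namespace.
    MetricQueryResults metrics = result.metrics().queryMetrics(
        MetricsFilter.builder()
            .addNameFilter(MetricNameFilter.inNamespace("MyDoFn"))
            .build());
    for (MetricResult<Long> counter : metrics.counters()) {
      System.out.println(counter.name() + ": " + counter.attempted());
    }
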
http://git-wip-us.apache.org/repos/asf/beam/blob/8e203ea2/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java
index 46492f8..cc20a30 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java
@@ -29,6 +29,7 @@ import org.apache.beam.runners.spark.aggregators.AccumulatorSingleton;
 import org.apache.beam.runners.spark.aggregators.NamedAggregators;
 import org.apache.beam.runners.spark.aggregators.SparkAggregators;
 import org.apache.beam.runners.spark.aggregators.metrics.AggregatorMetricSource;
+import org.apache.beam.runners.spark.metrics.SparkMetricsContainer;
 import org.apache.beam.runners.spark.translation.EvaluationContext;
 import org.apache.beam.runners.spark.translation.SparkContextFactory;
 import org.apache.beam.runners.spark.translation.SparkPipelineTranslator;
@@ -141,6 +142,8 @@ public final class SparkRunner extends PipelineRunner<SparkPipelineResult> {
     final Accumulator<NamedAggregators> accum =
         SparkAggregators.getOrCreateNamedAggregators(jsc, maybeCheckpointDir);
     final NamedAggregators initialValue = accum.value();
+    // Instantiate metrics accumulator
+    SparkMetricsContainer.getAccumulator(jsc);
 
     if (opts.getEnableSparkMetricSinks()) {
       final MetricsSystem metricsSystem = SparkEnv$.MODULE$.get().metricsSystem();

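The call above is made purely for its side effect: it registers the singleton
metrics accumulator on the driver before translation starts, so later lookups
(from the translators and from SparkPipelineResult.metrics()) all observe the
same instance. Conceptually:

    // jsc: the driver's JavaSparkContext.
    // First call creates and registers the accumulator on the driver...
    Accumulator<SparkMetricsContainer> a = SparkMetricsContainer.getAccumulator(jsc);
    // ...later calls, even with no context at hand, return the same instance.
    Accumulator<SparkMetricsContainer> b =
        SparkMetricsContainer.getAccumulator(SparkContextFactory.EMPTY_CONTEXT);
    assert a == b;
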
http://git-wip-us.apache.org/repos/asf/beam/blob/8e203ea2/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/MetricAggregator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/MetricAggregator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/MetricAggregator.java
new file mode 100644
index 0000000..79e49ce
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/MetricAggregator.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.spark.metrics;
+
+import java.io.Serializable;
+import org.apache.beam.sdk.metrics.DistributionData;
+import org.apache.beam.sdk.metrics.MetricKey;
+
+
+/**
+ * A wrapper around a metric value which adds aggregation methods.
+ * @param <ValueT> Metric value type.
+ */
+abstract class MetricAggregator<ValueT> implements Serializable {
+  private final MetricKey key;
+  protected ValueT value;
+
+  private MetricAggregator(MetricKey key, ValueT value) {
+    this.key = key;
+    this.value = value;
+  }
+
+  public MetricKey getKey() {
+    return key;
+  }
+
+  public ValueT getValue() {
+    return value;
+  }
+
+  @SuppressWarnings("unused")
+  abstract MetricAggregator<ValueT> updated(ValueT update);
+
+  static class CounterAggregator extends MetricAggregator<Long> {
+    CounterAggregator(MetricKey key, Long value) {
+      super(key, value);
+    }
+
+    @Override
+    CounterAggregator updated(Long counterUpdate) {
+      value = value + counterUpdate;
+      return this;
+    }
+  }
+
+  static class DistributionAggregator extends MetricAggregator<DistributionData> {
+    DistributionAggregator(MetricKey key, DistributionData value) {
+      super(key, value);
+    }
+
+    @Override
+    DistributionAggregator updated(DistributionData distributionUpdate) {
+      this.value = new SparkDistributionData(this.value.combine(distributionUpdate));
+      return this;
+    }
+  }
+
+  static class SparkDistributionData extends DistributionData implements Serializable {
+    private final long sum;
+    private final long count;
+    private final long min;
+    private final long max;
+
+    SparkDistributionData(DistributionData original) {
+      this.sum = original.sum();
+      this.count = original.count();
+      this.min = original.min();
+      this.max = original.max();
+    }
+
+    @Override
+    public long sum() {
+      return sum;
+    }
+
+    @Override
+    public long count() {
+      return count;
+    }
+
+    @Override
+    public long min() {
+      return min;
+    }
+
+    @Override
+    public long max() {
+      return max;
+    }
+  }
+
+  static <T> MetricAggregator<T> updated(MetricAggregator<T> metricAggregator, Object updateValue) {
+    //noinspection unchecked
+    return metricAggregator.updated((T) updateValue);
+  }
+}
+

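The updated(...) methods fold an update into the running value in place and
return this, so a sequence of updates accumulates. A conceptual sketch of the
counter case (the classes above are package-private; the MetricKey value "key"
is assumed to exist):

    CounterAggregator counter = new CounterAggregator(key, 0L);
    counter.updated(3L);              // value == 3
    counter.updated(4L);              // value == 7
    assert counter.getValue() == 7L;
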
http://git-wip-us.apache.org/repos/asf/beam/blob/8e203ea2/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/MetricsAccumulator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/MetricsAccumulator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/MetricsAccumulator.java
new file mode 100644
index 0000000..b8f0094
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/MetricsAccumulator.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.spark.metrics;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.apache.spark.Accumulator;
+import org.apache.spark.api.java.JavaSparkContext;
+
+
+/**
+ * For resilience, {@link Accumulator Accumulators} are required to be wrapped in a Singleton.
+ * @see <a href="https://spark.apache.org/docs/1.6.3/streaming-programming-guide.html#accumulators-and-broadcast-variables">accumulators</a>
+ */
+class MetricsAccumulator {
+
+  private static volatile Accumulator<SparkMetricsContainer> instance = null;
+
+  static Accumulator<SparkMetricsContainer> getInstance(JavaSparkContext jsc) {
+    if (instance == null) {
+      if (jsc == null) {
+        throw new IllegalStateException("Metrics accumulator has not been instantiated");
+      }
+      synchronized (MetricsAccumulator.class) {
+        if (instance == null) {
+          // TODO: currently when recovering from checkpoint, Spark does not recover the
+          // last known Accumulator value. The SparkRunner should be able to persist and recover
+          // the SparkMetricsContainer in order to recover metrics as well.
+          SparkMetricsContainer initialValue = new SparkMetricsContainer();
+          instance = jsc.sc().accumulator(initialValue, "Beam.Metrics",
+              new MetricsAccumulatorParam());
+        }
+      }
+    }
+    return instance;
+  }
+
+  @SuppressWarnings("unused")
+  @VisibleForTesting
+  static void clear() {
+    synchronized (MetricsAccumulator.class) {
+      instance = null;
+    }
+  }
+}

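getInstance(...) uses the classic double-checked locking idiom: the volatile
field makes the lock-free first null check safe under the Java memory model,
and the second check inside the synchronized block prevents two threads from
both constructing the accumulator. The general shape of the idiom:

    class Singleton {
      private static volatile Singleton instance = null;

      static Singleton getInstance() {
        if (instance == null) {             // fast path, no lock
          synchronized (Singleton.class) {
            if (instance == null) {         // re-check under the lock
              instance = new Singleton();
            }
          }
        }
        return instance;
      }
    }
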
http://git-wip-us.apache.org/repos/asf/beam/blob/8e203ea2/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/MetricsAccumulatorParam.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/MetricsAccumulatorParam.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/MetricsAccumulatorParam.java
new file mode 100644
index 0000000..032e283
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/MetricsAccumulatorParam.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.spark.metrics;
+
+import org.apache.spark.AccumulatorParam;
+
+
+/**
+ * Metrics accumulator param.
+ */
+class MetricsAccumulatorParam implements AccumulatorParam<SparkMetricsContainer> {
+  @Override
+  public SparkMetricsContainer addAccumulator(SparkMetricsContainer c1, SparkMetricsContainer c2) {
+    return c1.merge(c2);
+  }
+
+  @Override
+  public SparkMetricsContainer addInPlace(SparkMetricsContainer c1, SparkMetricsContainer c2) {
+    return c1.merge(c2);
+  }
+
+  @Override
+  public SparkMetricsContainer zero(SparkMetricsContainer initialValue) {
+    return initialValue;
+  }
+}

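Spark seeds per-task accumulator copies with zero(...) and folds task results
into the driver-side value through addAccumulator/addInPlace; since both
delegate to merge(), merging containers is the only semantics this param has
to define. A simplified model of that folding (not Spark's actual code):

    import org.apache.spark.AccumulatorParam;

    class AccumulatorFolding {
      // How Spark conceptually combines per-task values via an AccumulatorParam<T>.
      static <T> T fold(AccumulatorParam<T> param, T driverValue, Iterable<T> taskValues) {
        T total = driverValue;
        for (T taskValue : taskValues) {
          total = param.addInPlace(total, taskValue);  // here: SparkMetricsContainer.merge()
        }
        return total;
      }
    }
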
http://git-wip-us.apache.org/repos/asf/beam/blob/8e203ea2/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/SparkMetricResults.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/SparkMetricResults.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/SparkMetricResults.java
new file mode 100644
index 0000000..aea7b2e
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/SparkMetricResults.java
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.spark.metrics;
+
+import com.google.common.base.Function;
+import com.google.common.base.Objects;
+import com.google.common.base.Predicate;
+import com.google.common.collect.FluentIterable;
+import java.util.Set;
+import org.apache.beam.runners.spark.metrics.MetricAggregator.CounterAggregator;
+import org.apache.beam.runners.spark.metrics.MetricAggregator.DistributionAggregator;
+import org.apache.beam.sdk.metrics.DistributionResult;
+import org.apache.beam.sdk.metrics.MetricKey;
+import org.apache.beam.sdk.metrics.MetricName;
+import org.apache.beam.sdk.metrics.MetricNameFilter;
+import org.apache.beam.sdk.metrics.MetricQueryResults;
+import org.apache.beam.sdk.metrics.MetricResult;
+import org.apache.beam.sdk.metrics.MetricResults;
+import org.apache.beam.sdk.metrics.MetricsFilter;
+import org.apache.spark.Accumulator;
+
+
+/**
+ * Implementation of {@link MetricResults} for the Spark Runner.
+ */
+public class SparkMetricResults extends MetricResults {
+  private final Accumulator<SparkMetricsContainer> metricsAccum;
+
+  public SparkMetricResults(Accumulator<SparkMetricsContainer> metricsAccum) {
+    this.metricsAccum = metricsAccum;
+  }
+
+  @Override
+  public MetricQueryResults queryMetrics(MetricsFilter filter) {
+    return new SparkMetricQueryResults(filter);
+  }
+
+  private class SparkMetricQueryResults implements MetricQueryResults {
+    private final MetricsFilter filter;
+
+    SparkMetricQueryResults(MetricsFilter filter) {
+      this.filter = filter;
+    }
+
+    @Override
+    public Iterable<MetricResult<Long>> counters() {
+      return
+          FluentIterable
+              .from(metricsAccum.value().getCounters())
+              .filter(matchesFilter(filter))
+              .transform(TO_COUNTER_RESULT)
+              .toList();
+    }
+
+    @Override
+    public Iterable<MetricResult<DistributionResult>> distributions() {
+      return
+          FluentIterable
+              .from(metricsAccum.value().getDistributions())
+              .filter(matchesFilter(filter))
+              .transform(TO_DISTRIBUTION_RESULT)
+              .toList();
+    }
+
+    private Predicate<MetricAggregator<?>> matchesFilter(final MetricsFilter filter) {
+      return new Predicate<MetricAggregator<?>>() {
+        @Override
+        public boolean apply(MetricAggregator<?> metricResult) {
+          return matches(filter, metricResult.getKey());
+        }
+      };
+    }
+
+    private boolean matches(MetricsFilter filter, MetricKey key) {
+      return matchesName(key.metricName(), filter.names())
+          && matchesScope(key.stepName(), filter.steps());
+    }
+
+    private boolean matchesName(MetricName metricName, Set<MetricNameFilter> nameFilters) {
+      if (nameFilters.isEmpty()) {
+        return true;
+      }
+
+      for (MetricNameFilter nameFilter : nameFilters) {
+        if ((nameFilter.getName() == null || nameFilter.getName().equals(metricName.name()))
+            && Objects.equal(metricName.namespace(), nameFilter.getNamespace())) {
+          return true;
+        }
+      }
+
+      return false;
+    }
+
+    private boolean matchesScope(String actualScope, Set<String> scopes) {
+      if (scopes.isEmpty() || scopes.contains(actualScope)) {
+        return true;
+      }
+
+      for (String scope : scopes) {
+        if (actualScope.startsWith(scope)) {
+          return true;
+        }
+      }
+
+      return false;
+    }
+  }
+
+  private static final Function<DistributionAggregator, MetricResult<DistributionResult>>
+      TO_DISTRIBUTION_RESULT =
+      new Function<DistributionAggregator, MetricResult<DistributionResult>>() {
+        @Override
+        public MetricResult<DistributionResult>
+        apply(DistributionAggregator metricResult) {
+          if (metricResult != null) {
+            MetricKey key = metricResult.getKey();
+            return new SparkMetricResult<>(key.metricName(), key.stepName(),
+                metricResult.getValue().extractResult());
+          } else {
+            return null;
+          }
+        }
+      };
+
+  private static final Function<CounterAggregator, MetricResult<Long>>
+      TO_COUNTER_RESULT =
+      new Function<CounterAggregator, MetricResult<Long>>() {
+        @Override
+        public MetricResult<Long>
+        apply(CounterAggregator metricResult) {
+          if (metricResult != null) {
+            MetricKey key = metricResult.getKey();
+            return new SparkMetricResult<>(key.metricName(), key.stepName(),
+                metricResult.getValue());
+          } else {
+            return null;
+          }
+        }
+      };
+
+  private static class SparkMetricResult<T> implements MetricResult<T> {
+    private final MetricName name;
+    private final String step;
+    private final T result;
+
+    SparkMetricResult(MetricName name, String step, T result) {
+      this.name = name;
+      this.step = step;
+      this.result = result;
+    }
+
+    @Override
+    public MetricName name() {
+      return name;
+    }
+
+    @Override
+    public String step() {
+      return step;
+    }
+
+    @Override
+    public T committed() {
+      return result;
+    }
+
+    @Override
+    public T attempted() {
+      return result;
+    }
+  }
+}

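Two details of the matching logic above are worth noting: an empty name or step
filter matches everything, and matchesScope(...) treats each step filter as a
prefix, so filtering on a composite transform also matches its sub-steps. Note
also that SparkMetricResult reports the same value for committed() and
attempted(). A sketch of the prefix behavior (step names are hypothetical, and
metricsAccum is assumed to be in scope):

    MetricsFilter filter = MetricsFilter.builder()
        .addStep("MyComposite")  // matches "MyComposite" and e.g. "MyComposite/ParDo(Anonymous)"
        .build();
    MetricQueryResults results = new SparkMetricResults(metricsAccum).queryMetrics(filter);
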
http://git-wip-us.apache.org/repos/asf/beam/blob/8e203ea2/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/SparkMetricsContainer.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/SparkMetricsContainer.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/SparkMetricsContainer.java
new file mode 100644
index 0000000..0bf9612
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/SparkMetricsContainer.java
@@ -0,0 +1,288 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.spark.metrics;
+
+import com.google.common.base.Function;
+import com.google.common.base.Predicate;
+import com.google.common.cache.CacheBuilder;
+import com.google.common.cache.CacheLoader;
+import com.google.common.cache.LoadingCache;
+import com.google.common.collect.FluentIterable;
+import com.google.common.collect.Iterables;
+import java.io.IOException;
+import java.io.ObjectOutputStream;
+import java.io.Serializable;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.ExecutionException;
+import org.apache.beam.runners.spark.metrics.MetricAggregator.CounterAggregator;
+import org.apache.beam.runners.spark.metrics.MetricAggregator.DistributionAggregator;
+import org.apache.beam.runners.spark.metrics.MetricAggregator.SparkDistributionData;
+import org.apache.beam.sdk.metrics.DistributionData;
+import org.apache.beam.sdk.metrics.MetricKey;
+import org.apache.beam.sdk.metrics.MetricName;
+import org.apache.beam.sdk.metrics.MetricUpdates;
+import org.apache.beam.sdk.metrics.MetricUpdates.MetricUpdate;
+import org.apache.beam.sdk.metrics.MetricsContainer;
+import org.apache.spark.Accumulator;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * Spark accumulator value which holds all {@link MetricsContainer}s, aggregates and merges them.
+ */
+public class SparkMetricsContainer implements Serializable {
+  private static final Logger LOG = LoggerFactory.getLogger(SparkMetricsContainer.class);
+
+  private transient volatile LoadingCache<String, MetricsContainer> metricsContainers;
+
+  private final Map<MetricKey, MetricAggregator<?>> metrics = new HashMap<>();
+
+  SparkMetricsContainer() {}
+
+  public static Accumulator<SparkMetricsContainer> getAccumulator(JavaSparkContext jsc) {
+    return MetricsAccumulator.getInstance(jsc);
+  }
+
+  public MetricsContainer getContainer(String stepName) {
+    if (metricsContainers == null) {
+      synchronized (this) {
+        if (metricsContainers == null) {
+          metricsContainers =
+              CacheBuilder.newBuilder().build(new MetricsContainerCacheLoader());
+        }
+      }
+    }
+    try {
+      return metricsContainers.get(stepName);
+    } catch (ExecutionException e) {
+      LOG.error("Error while creating metrics container", e);
+      return null;
+    }
+  }
+
+  Collection<CounterAggregator> getCounters() {
+    return
+        FluentIterable
+            .from(metrics.values())
+            .filter(IS_COUNTER)
+            .transform(TO_COUNTER)
+            .toList();
+  }
+
+  private static final Predicate<MetricAggregator<?>> IS_COUNTER =
+      new Predicate<MetricAggregator<?>>() {
+        @Override
+        public boolean apply(MetricAggregator<?> input) {
+          return (input instanceof CounterAggregator);
+        }
+      };
+
+  private static final Function<MetricAggregator<?>, CounterAggregator> TO_COUNTER =
+      new Function<MetricAggregator<?>,
+          CounterAggregator>() {
+        @Override
+        public CounterAggregator apply(MetricAggregator<?> metricAggregator) {
+          return (CounterAggregator) metricAggregator;
+        }
+      };
+
+  Collection<DistributionAggregator> getDistributions() {
+    return
+        FluentIterable
+            .from(metrics.values())
+            .filter(IS_DISTRIBUTION)
+            .transform(TO_DISTRIBUTION)
+            .toList();
+  }
+
+  private static final Predicate<MetricAggregator<?>> IS_DISTRIBUTION =
+      new Predicate<MetricAggregator<?>>() {
+        @Override
+        public boolean apply(MetricAggregator<?> input) {
+          return (input instanceof DistributionAggregator);
+        }
+      };
+
+  private static final Function<MetricAggregator<?>, DistributionAggregator> TO_DISTRIBUTION =
+      new Function<MetricAggregator<?>, DistributionAggregator>() {
+        @Override
+        public DistributionAggregator apply(MetricAggregator<?> metricAggregator) {
+          return (DistributionAggregator) metricAggregator;
+        }
+      };
+
+  SparkMetricsContainer merge(SparkMetricsContainer other) {
+    return
+        new SparkMetricsContainer()
+            .updated(this.getAggregators())
+            .updated(other.getAggregators());
+  }
+
+  private Collection<MetricAggregator<?>> getAggregators() {
+    return metrics.values();
+  }
+
+  private void writeObject(ObjectOutputStream out) throws IOException {
+    materialize();
+    out.defaultWriteObject();
+  }
+
+  private void materialize() {
+    if (metricsContainers != null) {
+      for (MetricsContainer container : metricsContainers.asMap().values()) {
+        MetricUpdates cumulative = container.getCumulative();
+        updated(Iterables.transform(cumulative.counterUpdates(), TO_COUNTER_AGGREGATOR));
+        updated(Iterables.transform(cumulative.distributionUpdates(), TO_DISTRIBUTION_AGGREGATOR));
+      }
+    }
+  }
+
+  private static final Function<MetricUpdate<Long>, MetricAggregator<?>>
+      TO_COUNTER_AGGREGATOR = new Function<MetricUpdate<Long>, MetricAggregator<?>>() {
+    @SuppressWarnings("ConstantConditions")
+    @Override
+    public CounterAggregator
+    apply(MetricUpdate<Long> update) {
+      return update != null ? new CounterAggregator(new SparkMetricKey(update.getKey()),
+          update.getUpdate()) : null;
+    }
+  };
+
+  private static final Function<MetricUpdate<DistributionData>, MetricAggregator<?>>
+      TO_DISTRIBUTION_AGGREGATOR =
+      new Function<MetricUpdate<DistributionData>, MetricAggregator<?>>() {
+        @SuppressWarnings("ConstantConditions")
+        @Override
+        public DistributionAggregator
+        apply(MetricUpdate<DistributionData> update) {
+          return update != null ? new DistributionAggregator(new SparkMetricKey(update.getKey()),
+              new SparkDistributionData(update.getUpdate())) : null;
+        }
+      };
+
+  private SparkMetricsContainer updated(Iterable<MetricAggregator<?>> updates) {
+    for (MetricAggregator<?> update : updates) {
+      MetricKey key = update.getKey();
+      MetricAggregator<?> current = metrics.get(key);
+      Object updateValue = update.getValue();
+      metrics.put(new SparkMetricKey(key),
+          current != null ? MetricAggregator.updated(current, updateValue) : update);
+    }
+    return this;
+  }
+
+  private static class MetricsContainerCacheLoader extends CacheLoader<String, MetricsContainer> {
+    @SuppressWarnings("NullableProblems")
+    @Override
+    public MetricsContainer load(String stepName) throws Exception {
+      return new MetricsContainer(stepName);
+    }
+  }
+
+  private static class SparkMetricKey extends MetricKey implements Serializable {
+    private final String stepName;
+    private final MetricName metricName;
+
+    SparkMetricKey(MetricKey original) {
+      this.stepName = original.stepName();
+      MetricName metricName = original.metricName();
+      this.metricName = new SparkMetricName(metricName.namespace(), metricName.name());
+    }
+
+    @Override
+    public String stepName() {
+      return stepName;
+    }
+
+    @Override
+    public MetricName metricName() {
+      return metricName;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (o == this) {
+        return true;
+      }
+      if (o instanceof MetricKey) {
+        MetricKey that = (MetricKey) o;
+        return (this.stepName.equals(that.stepName()))
+            && (this.metricName.equals(that.metricName()));
+      }
+      return false;
+    }
+
+    @Override
+    public int hashCode() {
+      int h = 1;
+      h *= 1000003;
+      h ^= stepName.hashCode();
+      h *= 1000003;
+      h ^= metricName.hashCode();
+      return h;
+    }
+  }
+
+  private static class SparkMetricName extends MetricName implements Serializable {
+    private final String namespace;
+    private final String name;
+
+    SparkMetricName(String namespace, String name) {
+      this.namespace = namespace;
+      this.name = name;
+    }
+
+    @Override
+    public String namespace() {
+      return namespace;
+    }
+
+    @Override
+    public String name() {
+      return name;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (o == this) {
+        return true;
+      }
+      if (o instanceof MetricName) {
+        MetricName that = (MetricName) o;
+        return (this.namespace.equals(that.namespace()))
+            && (this.name.equals(that.name()));
+      }
+      return false;
+    }
+
+    @Override
+    public int hashCode() {
+      int h = 1;
+      h *= 1000003;
+      h ^= namespace.hashCode();
+      h *= 1000003;
+      h ^= name.hashCode();
+      return h;
+    }
+  }
+}

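The writeObject hook is the key to shipping results off the executors: before
Java serialization walks the fields, materialize() drains the transient cache
of live MetricsContainers into the serializable metrics map, and merge() then
combines the per-executor maps on the driver. The serialization pattern in
isolation, under simplified types:

    import java.io.IOException;
    import java.io.ObjectOutputStream;
    import java.io.Serializable;

    class Snapshotting implements Serializable {
      private transient Object liveState;     // rebuilt lazily, never serialized
      private String snapshot;                // what actually gets written

      // Invoked by the serialization machinery in place of the default path.
      private void writeObject(ObjectOutputStream out) throws IOException {
        snapshot = String.valueOf(liveState); // flatten live state first
        out.defaultWriteObject();             // then write the regular fields
      }
    }
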
http://git-wip-us.apache.org/repos/asf/beam/blob/8e203ea2/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/package-info.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/package-info.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/package-info.java
new file mode 100644
index 0000000..d0f8203
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/metrics/package-info.java
@@ -0,0 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Provides internal utilities for implementing Beam metrics using Spark accumulators. */
+package org.apache.beam.runners.spark.metrics;

http://git-wip-us.apache.org/repos/asf/beam/blob/8e203ea2/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java
index 4fd5e51..11761b6 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java
@@ -27,6 +27,7 @@ import org.apache.beam.runners.core.DoFnRunner;
 import org.apache.beam.runners.core.DoFnRunners;
 import org.apache.beam.runners.spark.aggregators.NamedAggregators;
 import org.apache.beam.runners.spark.aggregators.SparkAggregators;
+import org.apache.beam.runners.spark.metrics.SparkMetricsContainer;
 import org.apache.beam.runners.spark.util.SideInputBroadcast;
 import org.apache.beam.runners.spark.util.SparkSideInputReader;
 import org.apache.beam.sdk.transforms.DoFn;
@@ -47,38 +48,41 @@ import org.apache.spark.api.java.function.FlatMapFunction;
 public class DoFnFunction<InputT, OutputT>
     implements FlatMapFunction<Iterator<WindowedValue<InputT>>, WindowedValue<OutputT>> {
 
-  private final Accumulator<NamedAggregators> accumulator;
+  private final Accumulator<NamedAggregators> aggregatorsAccum;
+  private final Accumulator<SparkMetricsContainer> metricsAccum;
+  private final String stepName;
   private final DoFn<InputT, OutputT> doFn;
   private final SparkRuntimeContext runtimeContext;
   private final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs;
   private final WindowingStrategy<?, ?> windowingStrategy;
 
   /**
-   * @param accumulator       The Spark {@link Accumulator} that backs the Beam Aggregators.
+   * @param aggregatorsAccum  The Spark {@link Accumulator} that backs the Beam Aggregators.
    * @param doFn              The {@link DoFn} to be wrapped.
    * @param runtimeContext    The {@link SparkRuntimeContext}.
    * @param sideInputs        Side inputs used in this {@link DoFn}.
    * @param windowingStrategy Input {@link WindowingStrategy}.
    */
   public DoFnFunction(
-      Accumulator<NamedAggregators> accumulator,
+      Accumulator<NamedAggregators> aggregatorsAccum,
+      Accumulator<SparkMetricsContainer> metricsAccum,
+      String stepName,
       DoFn<InputT, OutputT> doFn,
       SparkRuntimeContext runtimeContext,
       Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs,
       WindowingStrategy<?, ?> windowingStrategy) {
-
-    this.accumulator = accumulator;
+    this.aggregatorsAccum = aggregatorsAccum;
+    this.metricsAccum = metricsAccum;
+    this.stepName = stepName;
     this.doFn = doFn;
     this.runtimeContext = runtimeContext;
     this.sideInputs = sideInputs;
     this.windowingStrategy = windowingStrategy;
   }
 
-
   @Override
   public Iterable<WindowedValue<OutputT>> call(
       Iterator<WindowedValue<InputT>> iter) throws Exception {
-
     DoFnOutputManager outputManager = new DoFnOutputManager();
 
     DoFnRunner<InputT, OutputT> doFnRunner =
@@ -91,10 +95,14 @@ public class DoFnFunction<InputT, OutputT>
             },
             Collections.<TupleTag<?>>emptyList(),
             new SparkProcessContext.NoOpStepContext(),
-            new SparkAggregators.Factory(runtimeContext, accumulator),
+            new SparkAggregators.Factory(runtimeContext, aggregatorsAccum),
             windowingStrategy);
 
-    return new SparkProcessContext<>(doFn, doFnRunner, outputManager).processPartition(iter);
+    DoFnRunner<InputT, OutputT> doFnRunnerWithMetrics =
+        new DoFnRunnerWithMetrics<>(stepName, doFnRunner, metricsAccum);
+
+    return new SparkProcessContext<>(doFn, doFnRunnerWithMetrics, outputManager)
+        .processPartition(iter);
   }
 
   private class DoFnOutputManager

http://git-wip-us.apache.org/repos/asf/beam/blob/8e203ea2/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnRunnerWithMetrics.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnRunnerWithMetrics.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnRunnerWithMetrics.java
new file mode 100644
index 0000000..d9366ca
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnRunnerWithMetrics.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.spark.translation;
+
+import java.io.Closeable;
+import java.io.IOException;
+import org.apache.beam.runners.core.DoFnRunner;
+import org.apache.beam.runners.spark.metrics.SparkMetricsContainer;
+import org.apache.beam.sdk.metrics.MetricsContainer;
+import org.apache.beam.sdk.metrics.MetricsEnvironment;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.TimeDomain;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.spark.Accumulator;
+import org.joda.time.Instant;
+
+
+/**
+ * A {@link DoFnRunner} decorator which installs the step's metrics container around each call.
+ */
+class DoFnRunnerWithMetrics<InputT, OutputT> implements DoFnRunner<InputT, OutputT> {
+  private final DoFnRunner<InputT, OutputT> delegate;
+  private final String stepName;
+  private final Accumulator<SparkMetricsContainer> metricsAccum;
+
+  DoFnRunnerWithMetrics(String stepName, DoFnRunner<InputT, OutputT> delegate,
+                        Accumulator<SparkMetricsContainer> metricsAccum) {
+    this.delegate = delegate;
+    this.stepName = stepName;
+    this.metricsAccum = metricsAccum;
+  }
+
+  @Override
+  public void startBundle() {
+    doWithMetricsContainer(new Runnable() {
+      @Override
+      public void run() {
+        delegate.startBundle();
+      }
+    });
+  }
+
+  @Override
+  public void processElement(final WindowedValue<InputT> elem) {
+    doWithMetricsContainer(new Runnable() {
+      @Override
+      public void run() {
+        delegate.processElement(elem);
+      }
+    });
+  }
+
+  @Override
+  public void onTimer(final String timerId, final BoundedWindow window, final Instant timestamp,
+                      final TimeDomain timeDomain) {
+    doWithMetricsContainer(new Runnable() {
+      @Override
+      public void run() {
+        delegate.onTimer(timerId, window, timestamp, timeDomain);
+      }
+    });
+  }
+
+  @Override
+  public void finishBundle() {
+    doWithMetricsContainer(new Runnable() {
+      @Override
+      public void run() {
+        delegate.finishBundle();
+      }
+    });
+  }
+
+  private void doWithMetricsContainer(Runnable runnable) {
+    MetricsContainer metricsContainer = metricsAccum.localValue().getContainer(stepName);
+    try (Closeable ignored = MetricsEnvironment.scopedMetricsContainer(metricsContainer)) {
+      runnable.run();
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+  }
+}

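Because the decorator installs the step's MetricsContainer around every
delegated call via MetricsEnvironment.scopedMetricsContainer(...), user code
needs no Spark-specific wiring: metrics declared through the SDK are picked up
automatically. A sketch of a DoFn whose counter flows through this path
(CountingFn and the "seen" counter are illustrative names):

    import org.apache.beam.sdk.metrics.Counter;
    import org.apache.beam.sdk.metrics.Metrics;
    import org.apache.beam.sdk.transforms.DoFn;

    class CountingFn extends DoFn<String, String> {
      private final Counter seen = Metrics.counter(CountingFn.class, "seen");

      @ProcessElement
      public void processElement(ProcessContext c) {
        seen.inc();  // recorded in the step's scoped MetricsContainer
        c.output(c.element());
      }
    }
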
http://git-wip-us.apache.org/repos/asf/beam/blob/8e203ea2/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java
index 9096d5a..a35aff2 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java
@@ -89,6 +89,10 @@ public class EvaluationContext {
     this.currentTransform = transform;
   }
 
+  public AppliedPTransform<?, ?, ?> getCurrentTransform() {
+    return currentTransform;
+  }
+
   public <T extends PValue> T getInput(PTransform<T, ?> transform) {
     @SuppressWarnings("unchecked")
     T input = (T) Iterables.getOnlyElement(getInputs(transform)).getValue();

http://git-wip-us.apache.org/repos/asf/beam/blob/8e203ea2/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkContextFactory.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkContextFactory.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkContextFactory.java
index 326838a..bd26ba1 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkContextFactory.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkContextFactory.java
@@ -33,6 +33,8 @@ import org.slf4j.LoggerFactory;
 public final class SparkContextFactory {
   private static final Logger LOG = LoggerFactory.getLogger(SparkContextFactory.class);
 
+  public static final JavaSparkContext EMPTY_CONTEXT = null;
+
   /**
    * If the property {@code beam.spark.test.reuseSparkContext} is set to
    * {@code true} then the Spark context will be reused for beam pipelines.

http://git-wip-us.apache.org/repos/asf/beam/blob/8e203ea2/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
index f0e339a..3d75142 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
@@ -42,6 +42,7 @@ import org.apache.beam.runners.spark.io.hadoop.HadoopIO;
 import org.apache.beam.runners.spark.io.hadoop.ShardNameTemplateHelper;
 import org.apache.beam.runners.spark.io.hadoop.TemplatedAvroKeyOutputFormat;
 import org.apache.beam.runners.spark.io.hadoop.TemplatedTextOutputFormat;
+import org.apache.beam.runners.spark.metrics.SparkMetricsContainer;
 import org.apache.beam.runners.spark.util.SideInputBroadcast;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.KvCoder;
@@ -240,6 +241,7 @@ public final class TransformTranslator {
     return new TransformEvaluator<ParDo.Bound<InputT, OutputT>>() {
       @Override
       public void evaluate(ParDo.Bound<InputT, OutputT> transform, EvaluationContext context) {
+        String stepName = context.getCurrentTransform().getFullName();
         DoFn<InputT, OutputT> doFn = transform.getFn();
         rejectStateAndTimers(doFn);
         @SuppressWarnings("unchecked")
@@ -247,13 +249,16 @@ public final class TransformTranslator {
             ((BoundedDataset<InputT>) context.borrowDataset(transform)).getRDD();
         WindowingStrategy<?, ?> windowingStrategy =
             context.getInput(transform).getWindowingStrategy();
-        Accumulator<NamedAggregators> accum =
-            SparkAggregators.getNamedAggregators(context.getSparkContext());
+        JavaSparkContext jsc = context.getSparkContext();
+        Accumulator<NamedAggregators> aggAccum =
+            SparkAggregators.getNamedAggregators(jsc);
+        Accumulator<SparkMetricsContainer> metricsAccum =
+            SparkMetricsContainer.getAccumulator(jsc);
         Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs =
             TranslationUtils.getSideInputs(transform.getSideInputs(), context);
         context.putDataset(transform,
-            new BoundedDataset<>(inRDD.mapPartitions(new DoFnFunction<>(accum, doFn,
-                context.getRuntimeContext(), sideInputs, windowingStrategy))));
+            new BoundedDataset<>(inRDD.mapPartitions(new DoFnFunction<>(aggAccum, metricsAccum,
+                stepName, doFn, context.getRuntimeContext(), sideInputs, windowingStrategy))));
       }
     };
   }

http://git-wip-us.apache.org/repos/asf/beam/blob/8e203ea2/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
index a2a1d3b..c9ab2b3 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
@@ -30,6 +30,7 @@ import org.apache.beam.runners.spark.aggregators.SparkAggregators;
 import org.apache.beam.runners.spark.io.ConsoleIO;
 import org.apache.beam.runners.spark.io.CreateStream;
 import org.apache.beam.runners.spark.io.SparkUnboundedSource;
+import org.apache.beam.runners.spark.metrics.SparkMetricsContainer;
 import org.apache.beam.runners.spark.translation.BoundedDataset;
 import org.apache.beam.runners.spark.translation.Dataset;
 import org.apache.beam.runners.spark.translation.DoFnFunction;
@@ -383,6 +384,8 @@ final class StreamingTransformTranslator {
         JavaDStream<WindowedValue<InputT>> dStream =
             ((UnboundedDataset<InputT>) context.borrowDataset(transform)).getDStream();
 
+        final String stepName = context.getCurrentTransform().getFullName();
+
         JavaDStream<WindowedValue<OutputT>> outStream =
             dStream.transform(new Function<JavaRDD<WindowedValue<InputT>>,
                 JavaRDD<WindowedValue<OutputT>>>() {
@@ -390,15 +393,16 @@ final class StreamingTransformTranslator {
           public JavaRDD<WindowedValue<OutputT>> call(JavaRDD<WindowedValue<InputT>> rdd) throws
               Exception {
             final JavaSparkContext jsc = new JavaSparkContext(rdd.context());
-
-            final Accumulator<NamedAggregators> accum =
+            final Accumulator<NamedAggregators> aggAccum =
                 SparkAggregators.getNamedAggregators(jsc);
-
+            final Accumulator<SparkMetricsContainer> metricsAccum =
+                SparkMetricsContainer.getAccumulator(jsc);
             final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs =
                 TranslationUtils.getSideInputs(transform.getSideInputs(),
                     jsc, pviews);
             return rdd.mapPartitions(
-                new DoFnFunction<>(accum, doFn, runtimeContext, sideInputs, windowingStrategy));
+                new DoFnFunction<>(aggAccum, metricsAccum, stepName, doFn, runtimeContext,
+                    sideInputs, windowingStrategy));
           }
         });
 

http://git-wip-us.apache.org/repos/asf/beam/blob/8e203ea2/sdks/java/core/src/test/java/org/apache/beam/sdk/metrics/MetricMatchers.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/metrics/MetricMatchers.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/metrics/MetricMatchers.java
index 3648c05..5de8894 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/metrics/MetricMatchers.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/metrics/MetricMatchers.java
@@ -165,6 +165,102 @@ public class MetricMatchers {
     };
   }
 
+  static Matcher<MetricResult<DistributionResult>> distributionAttemptedMinMax(
+      final String namespace, final String name, final String step,
+      final Long attemptedMin, final Long attemptedMax) {
+    return new TypeSafeMatcher<MetricResult<DistributionResult>>() {
+      @Override
+      protected boolean matchesSafely(MetricResult<DistributionResult> item) {
+        return Objects.equals(namespace, item.name().namespace())
+            && Objects.equals(name, item.name().name())
+            && item.step().contains(step)
+            && Objects.equals(attemptedMin, item.attempted().min())
+            && Objects.equals(attemptedMax, item.attempted().max());
+      }
+
+      @Override
+      public void describeTo(Description description) {
+        description
+            .appendText("MetricResult{inNamespace=").appendValue(namespace)
+            .appendText(", name=").appendValue(name)
+            .appendText(", step=").appendValue(step)
+            .appendText(", attemptedMin=").appendValue(attemptedMin)
+            .appendText(", attemptedMax=").appendValue(attemptedMax)
+            .appendText("}");
+      }
+
+      @Override
+      protected void describeMismatchSafely(MetricResult<DistributionResult> item,
+          Description mismatchDescription) {
+        mismatchDescription.appendText("MetricResult{");
+
+        describeMetricsResultMembersMismatch(item, mismatchDescription, namespace, name, step);
+
+        if (!Objects.equals(attemptedMin, item.attempted().min())) {
+          mismatchDescription
+              .appendText("attemptedMin: ").appendValue(attemptedMin)
+              .appendText(" != ").appendValue(item.attempted().min());
+        }
+
+        if (!Objects.equals(attemptedMax, item.attempted().max())) {
+          mismatchDescription
+              .appendText("attemptedMax: ").appendValue(attemptedMax)
+              .appendText(" != ").appendValue(item.attempted().max());
+        }
+
+        mismatchDescription.appendText("}");
+      }
+    };
+  }
+
+  static Matcher<MetricResult<DistributionResult>> distributionCommittedMinMax(
+      final String namespace, final String name, final String step,
+      final Long committedMin, final Long committedMax) {
+    return new TypeSafeMatcher<MetricResult<DistributionResult>>() {
+      @Override
+      protected boolean matchesSafely(MetricResult<DistributionResult> item) {
+        return Objects.equals(namespace, item.name().namespace())
+            && Objects.equals(name, item.name().name())
+            && item.step().contains(step)
+            && Objects.equals(committedMin, item.committed().min())
+            && Objects.equals(committedMax, item.committed().max());
+      }
+
+      @Override
+      public void describeTo(Description description) {
+        description
+            .appendText("MetricResult{inNamespace=").appendValue(namespace)
+            .appendText(", name=").appendValue(name)
+            .appendText(", step=").appendValue(step)
+            .appendText(", committedMin=").appendValue(committedMin)
+            .appendText(", committedMax=").appendValue(committedMax)
+            .appendText("}");
+      }
+
+      @Override
+      protected void describeMismatchSafely(MetricResult<DistributionResult> item,
+          Description mismatchDescription) {
+        mismatchDescription.appendText("MetricResult{");
+
+        describeMetricsResultMembersMismatch(item, mismatchDescription, namespace, name, step);
+
+        if (!Objects.equals(committedMin, item.committed().min())) {
+          mismatchDescription
+              .appendText("committedMin: ").appendValue(committedMin)
+              .appendText(" != ").appendValue(item.committed().min());
+        }
+
+        if (!Objects.equals(committedMax, item.committed().max())) {
+          mismatchDescription
+              .appendText("committedMax: ").appendValue(committedMax)
+              .appendText(" != ").appendValue(item.committed().max());
+        }
+
+        mismatchDescription.appendText("}");
+      }
+    };
+  }
+
   private static <T> void describeMetricsResultMembersMismatch(
       MetricResult<T> item,
       Description mismatchDescription,

http://git-wip-us.apache.org/repos/asf/beam/blob/8e203ea2/sdks/java/core/src/test/java/org/apache/beam/sdk/metrics/MetricsTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/metrics/MetricsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/metrics/MetricsTest.java
index 9ad0935..57a1d23 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/metrics/MetricsTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/metrics/MetricsTest.java
@@ -20,6 +20,8 @@ package org.apache.beam.sdk.metrics;
 
 import static org.apache.beam.sdk.metrics.MetricMatchers.attemptedMetricsResult;
 import static org.apache.beam.sdk.metrics.MetricMatchers.committedMetricsResult;
+import static org.apache.beam.sdk.metrics.MetricMatchers.distributionAttemptedMinMax;
+import static org.apache.beam.sdk.metrics.MetricMatchers.distributionCommittedMinMax;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasItem;
 import static org.junit.Assert.assertNull;
@@ -116,8 +118,8 @@ public class MetricsTest implements Serializable {
     PipelineResult result = runPipelineWithMetrics();
 
     MetricQueryResults metrics = result.metrics().queryMetrics(MetricsFilter.builder()
-      .addNameFilter(MetricNameFilter.inNamespace(MetricsTest.class))
-      .build());
+        .addNameFilter(MetricNameFilter.inNamespace(MetricsTest.class))
+        .build());
 
     assertThat(metrics.counters(), hasItem(
         committedMetricsResult(MetricsTest.class.getName(), "count", "MyStep1", 3L)));
@@ -130,6 +132,9 @@ public class MetricsTest implements Serializable {
     assertThat(metrics.distributions(), hasItem(
         committedMetricsResult(MetricsTest.class.getName(), "input", "MyStep2",
             DistributionResult.create(52L, 6L, 5L, 13L))));
+
+    assertThat(metrics.distributions(), hasItem(
+        distributionCommittedMinMax(MetricsTest.class.getName(), "bundle", "MyStep1", 10L, 40L)));
   }
 
 
@@ -154,6 +159,9 @@ public class MetricsTest implements Serializable {
     assertThat(metrics.distributions(), hasItem(
         attemptedMetricsResult(MetricsTest.class.getName(), "input", "MyStep2",
             DistributionResult.create(52L, 6L, 5L, 13L))));
+
+    assertThat(metrics.distributions(), hasItem(
+        distributionAttemptedMinMax(MetricsTest.class.getName(), "bundle", "MyStep1", 10L, 40L)));
   }
 
   private PipelineResult runPipelineWithMetrics() {
@@ -162,6 +170,13 @@ public class MetricsTest implements Serializable {
     pipeline
         .apply(Create.of(5, 8, 13))
         .apply("MyStep1", ParDo.of(new DoFn<Integer, Integer>() {
+          Distribution bundleDist = Metrics.distribution(MetricsTest.class, "bundle");
+
+          @StartBundle
+          public void startBundle(Context c) {
+            bundleDist.update(10L);
+          }
+
           @SuppressWarnings("unused")
           @ProcessElement
           public void processElement(ProcessContext c) {
@@ -172,6 +187,11 @@ public class MetricsTest implements Serializable {
             c.output(c.element());
             c.output(c.element());
           }
+
+          @FinishBundle
+          public void finishBundle(Context c) {
+            bundleDist.update(40L);
+          }
         }))
         .apply("MyStep2", ParDo.of(new DoFn<Integer, Integer>() {
           @SuppressWarnings("unused")

