beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From amits...@apache.org
Subject [3/4] incubator-beam git commit: [BEAM-610] Enable spark's checkpointing mechanism for driver-failure recovery in streaming.
Date Wed, 21 Sep 2016 17:25:08 GMT
[BEAM-610] Enable spark's checkpointing mechanism for driver-failure recovery in streaming.

Refactor translation mechanism to support checkpointing of DStream.

Support basic functionality with GroupByKey and ParDo.

Added support for grouping operations.

Added checkpointDir option, using it before execution.

Support Accumulator recovery from checkpoint.

Streaming tests should use JUnit's TemporaryFolder Rule for checkpoint directory.

Support combine optimizations.

Support durable sideInput via Broadcast.

Branches in the pipeline are either Bounded or Unbounded and should be handled as such.

Handle flatten/union of Bounded/Unbounded RDD/DStream.

JavaDoc

Rebased on master.

Reuse functionality between batch and streaming translators

Better implementation of streaming/batch pipeline-branch translation.

Move group/combine functions to their own wrapping class.

Fixed missing licenses.

Use VisibleForTesting annotation instead of comment.

Remove Broadcast failure recovery, to be handled separately.

Stop streaming gracefully, so any checkpointing will finish first.

typo + better documentation.

Validate checkpointDir durability.

Add checkpoint duration option.

A more compact streaming tests init with Rules.

A more accurate test, removed broadcast from test as it will be handled separately.

Bounded/Unbounded translation to be handled by the SparkPipelineTranslator implementation. Evaluator
decides if translateBounded or translateUnbounded according to the visited node's boundedness.


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/0feb6499
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/0feb6499
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/0feb6499

Branch: refs/heads/master
Commit: 0feb64994a05de4fe6b1ba178a38d03743b89b7a
Parents: 5c23f49
Author: Sela <ansela@paypal.com>
Authored: Thu Aug 25 23:49:01 2016 +0300
Committer: Sela <ansela@paypal.com>
Committed: Wed Sep 21 20:15:27 2016 +0300

----------------------------------------------------------------------
 .../runners/spark/SparkPipelineOptions.java     |  28 +-
 .../apache/beam/runners/spark/SparkRunner.java  | 121 ++--
 .../spark/aggregators/AccumulatorSingleton.java |  53 ++
 .../runners/spark/translation/DoFnFunction.java |  35 +-
 .../spark/translation/EvaluationContext.java    |  17 +-
 .../translation/GroupCombineFunctions.java      | 262 +++++++++
 .../spark/translation/MultiDoFnFunction.java    |  44 +-
 .../spark/translation/SparkContextFactory.java  |  48 +-
 .../translation/SparkPipelineEvaluator.java     |  57 --
 .../translation/SparkPipelineTranslator.java    |   5 +-
 .../spark/translation/SparkProcessContext.java  |  10 +-
 .../spark/translation/SparkRuntimeContext.java  |  44 +-
 .../spark/translation/TransformTranslator.java  | 473 +++-------------
 .../spark/translation/TranslationUtils.java     | 195 +++++++
 .../SparkRunnerStreamingContextFactory.java     |  98 ++++
 .../streaming/StreamingEvaluationContext.java   |  44 +-
 .../streaming/StreamingTransformTranslator.java | 549 ++++++++++++-------
 .../runners/spark/util/BroadcastHelper.java     |  26 +
 .../runners/spark/ClearAggregatorsRule.java     |  33 ++
 .../beam/runners/spark/SimpleWordCountTest.java |   4 +
 .../spark/translation/SideEffectsTest.java      |   3 +-
 .../streaming/FlattenStreamingTest.java         |  54 +-
 .../streaming/KafkaStreamingTest.java           |  26 +-
 .../RecoverFromCheckpointStreamingTest.java     | 179 ++++++
 .../streaming/SimpleStreamingWordCountTest.java |  25 +-
 .../utils/TestOptionsForStreaming.java          |  55 ++
 .../org/apache/beam/sdk/transforms/Combine.java |   7 +
 27 files changed, 1682 insertions(+), 813 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0feb6499/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java
index db6b75c..7afb68c 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java
@@ -19,9 +19,9 @@
 package org.apache.beam.runners.spark;
 
 import com.fasterxml.jackson.annotation.JsonIgnore;
-
 import org.apache.beam.sdk.options.ApplicationNameOptions;
 import org.apache.beam.sdk.options.Default;
+import org.apache.beam.sdk.options.DefaultValueFactory;
 import org.apache.beam.sdk.options.Description;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.StreamingOptions;
@@ -48,6 +48,32 @@ public interface SparkPipelineOptions extends PipelineOptions, StreamingOptions,
   Long getBatchIntervalMillis();
   void setBatchIntervalMillis(Long batchInterval);
 
+  @Description("A checkpoint directory for streaming resilience, ignored in batch. "
+      + "For durability, a reliable filesystem such as HDFS/S3/GS is necessary.")
+  @Default.InstanceFactory(TmpCheckpointDirFactory.class)
+  String getCheckpointDir();
+  void setCheckpointDir(String checkpointDir);
+
+  /**
+   * Returns the default checkpoint directory of /tmp/${job.name}.
+   * For testing purposes only. Production applications should use a reliable
+   * filesystem such as HDFS/S3/GS.
+   */
+  static class TmpCheckpointDirFactory implements DefaultValueFactory<String> {
+    @Override
+    public String create(PipelineOptions options) {
+      SparkPipelineOptions sparkPipelineOptions = options.as(SparkPipelineOptions.class);
+      return "file:///tmp/" + sparkPipelineOptions.getJobName();
+    }
+  }
+
+  @Description("The period to checkpoint (in Millis). If not set, Spark will default "
+      + "to Max(slideDuration, Seconds(10)). This PipelineOptions default (-1) will end-up "
+          + "with the described Spark default.")
+  @Default.Long(-1)
+  Long getCheckpointDurationMillis();
+  void setCheckpointDurationMillis(Long durationMillis);
+
   @Description("Enable/disable sending aggregator values to Spark's metric sinks")
   @Default.Boolean(true)
   Boolean getEnableSparkSinks();

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0feb6499/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java
index 03db811..63dfe0d 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java
@@ -18,15 +18,16 @@
 
 package org.apache.beam.runners.spark;
 
+import java.util.Collection;
 import org.apache.beam.runners.core.GroupByKeyViaGroupByKeyOnly;
 import org.apache.beam.runners.spark.translation.EvaluationContext;
 import org.apache.beam.runners.spark.translation.SparkContextFactory;
-import org.apache.beam.runners.spark.translation.SparkPipelineEvaluator;
 import org.apache.beam.runners.spark.translation.SparkPipelineTranslator;
 import org.apache.beam.runners.spark.translation.SparkProcessContext;
+import org.apache.beam.runners.spark.translation.TransformEvaluator;
 import org.apache.beam.runners.spark.translation.TransformTranslator;
+import org.apache.beam.runners.spark.translation.streaming.SparkRunnerStreamingContextFactory;
 import org.apache.beam.runners.spark.translation.streaming.StreamingEvaluationContext;
-import org.apache.beam.runners.spark.translation.streaming.StreamingTransformTranslator;
 import org.apache.beam.runners.spark.util.SinglePrimitiveOutputPTransform;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.options.PipelineOptions;
@@ -34,15 +35,17 @@ import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.options.PipelineOptionsValidator;
 import org.apache.beam.sdk.runners.PipelineRunner;
 import org.apache.beam.sdk.runners.TransformTreeNode;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PInput;
 import org.apache.beam.sdk.values.POutput;
+import org.apache.beam.sdk.values.PValue;
 import org.apache.spark.SparkException;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.streaming.Duration;
 import org.apache.spark.streaming.api.java.JavaStreamingContext;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -143,40 +146,27 @@ public final class SparkRunner extends PipelineRunner<EvaluationResult> {
   public EvaluationResult run(Pipeline pipeline) {
     try {
       LOG.info("Executing pipeline using the SparkRunner.");
-      JavaSparkContext jsc;
-      if (mOptions.getUsesProvidedSparkContext()) {
-        LOG.info("Using a provided Spark Context");
-        jsc = mOptions.getProvidedSparkContext();
-        if (jsc == null || jsc.sc().isStopped()){
-          LOG.error("The provided Spark context "
-                  + jsc + " was not created or was stopped");
-          throw new RuntimeException("The provided Spark context was not created or was stopped");
-        }
-      } else {
-        LOG.info("Creating a new Spark Context");
-        jsc = SparkContextFactory.getSparkContext(mOptions.getSparkMaster(), mOptions.getAppName());
-      }
-      if (mOptions.isStreaming()) {
-        SparkPipelineTranslator translator =
-            new StreamingTransformTranslator.Translator(new TransformTranslator.Translator());
-        Duration batchInterval = new Duration(mOptions.getBatchIntervalMillis());
-        LOG.info("Setting Spark streaming batchInterval to {} msec", batchInterval.milliseconds());
 
-        EvaluationContext ctxt = createStreamingEvaluationContext(jsc, pipeline, batchInterval);
-        pipeline.traverseTopologically(new SparkPipelineEvaluator(ctxt, translator));
-        ctxt.computeOutputs();
+      if (mOptions.isStreaming()) {
+        SparkRunnerStreamingContextFactory contextFactory =
+            new SparkRunnerStreamingContextFactory(pipeline, mOptions);
+        JavaStreamingContext jssc = JavaStreamingContext.getOrCreate(mOptions.getCheckpointDir(),
+            contextFactory);
 
-        LOG.info("Streaming pipeline construction complete. Starting execution..");
-        ((StreamingEvaluationContext) ctxt).getStreamingContext().start();
+        LOG.info("Starting streaming pipeline execution.");
+        jssc.start();
 
-        return ctxt;
+        // if recovering from checkpoint, we have to reconstruct the EvaluationResult instance.
+        return contextFactory.getCtxt() == null ? new StreamingEvaluationContext(jssc.sc(),
+            pipeline, jssc, mOptions.getTimeout()) : contextFactory.getCtxt();
       } else {
         if (mOptions.getTimeout() > 0) {
           LOG.info("Timeout is ignored by the SparkRunner in batch.");
         }
+        JavaSparkContext jsc = SparkContextFactory.getSparkContext(mOptions);
         EvaluationContext ctxt = new EvaluationContext(jsc, pipeline);
         SparkPipelineTranslator translator = new TransformTranslator.Translator();
-        pipeline.traverseTopologically(new SparkPipelineEvaluator(ctxt, translator));
+        pipeline.traverseTopologically(new Evaluator(translator, ctxt));
         ctxt.computeOutputs();
 
         LOG.info("Pipeline execution complete.");
@@ -202,23 +192,18 @@ public final class SparkRunner extends PipelineRunner<EvaluationResult> {
     }
   }
 
-  private EvaluationContext
-      createStreamingEvaluationContext(JavaSparkContext jsc, Pipeline pipeline,
-      Duration batchDuration) {
-    JavaStreamingContext jssc = new JavaStreamingContext(jsc, batchDuration);
-    return new StreamingEvaluationContext(jsc, pipeline, jssc, mOptions.getTimeout());
-  }
-
   /**
    * Evaluator on the pipeline.
    */
-  public abstract static class Evaluator extends Pipeline.PipelineVisitor.Defaults {
-    protected static final Logger LOG = LoggerFactory.getLogger(Evaluator.class);
+  public static class Evaluator extends Pipeline.PipelineVisitor.Defaults {
+    private static final Logger LOG = LoggerFactory.getLogger(Evaluator.class);
 
-    protected final SparkPipelineTranslator translator;
+    private final EvaluationContext ctxt;
+    private final SparkPipelineTranslator translator;
 
-    protected Evaluator(SparkPipelineTranslator translator) {
+    public Evaluator(SparkPipelineTranslator translator, EvaluationContext ctxt) {
       this.translator = translator;
+      this.ctxt = ctxt;
     }
 
     @Override
@@ -242,8 +227,62 @@ public final class SparkRunner extends PipelineRunner<EvaluationResult> {
       doVisitTransform(node);
     }
 
-    protected abstract <TransformT extends PTransform<? super PInput, POutput>> void
-        doVisitTransform(TransformTreeNode node);
+    <TransformT extends PTransform<? super PInput, POutput>> void
+        doVisitTransform(TransformTreeNode node) {
+      @SuppressWarnings("unchecked")
+      TransformT transform = (TransformT) node.getTransform();
+      @SuppressWarnings("unchecked")
+      Class<TransformT> transformClass = (Class<TransformT>) (Class<?>) transform.getClass();
+      @SuppressWarnings("unchecked") TransformEvaluator<TransformT> evaluator =
+          translate(node, transform, transformClass);
+      LOG.info("Evaluating {}", transform);
+      AppliedPTransform<PInput, POutput, TransformT> appliedTransform =
+          AppliedPTransform.of(node.getFullName(), node.getInput(), node.getOutput(), transform);
+      ctxt.setCurrentTransform(appliedTransform);
+      evaluator.evaluate(transform, ctxt);
+      ctxt.setCurrentTransform(null);
+    }
+
+    /**
+     *  Determine if this Node belongs to a Bounded branch of the pipeline, or Unbounded, and
+     *  translate with the proper translator.
+     */
+    private <TransformT extends PTransform<? super PInput, POutput>> TransformEvaluator<TransformT>
+        translate(TransformTreeNode node, TransformT transform, Class<TransformT> transformClass) {
+      //--- determine if node is bounded/unbounded.
+      // usually, the input determines if the PCollection to apply the next transformation to
+      // is BOUNDED or UNBOUNDED, meaning RDD/DStream.
+      Collection<? extends PValue> pValues;
+      PInput pInput = node.getInput();
+      if (pInput instanceof PBegin) {
+        // in case of a PBegin, it's the output.
+        pValues = node.getOutput().expand();
+      } else {
+        pValues = pInput.expand();
+      }
+      PCollection.IsBounded isNodeBounded = isBoundedCollection(pValues);
+      // translate accordingly.
+      LOG.debug("Translating {} as {}", transform, isNodeBounded);
+      return isNodeBounded.equals(PCollection.IsBounded.BOUNDED)
+          ? translator.translateBounded(transformClass)
+              : translator.translateUnbounded(transformClass);
+    }
+
+    private PCollection.IsBounded isBoundedCollection(Collection<? extends PValue> pValues) {
+      // anything that is not a PCollection, is BOUNDED.
+      // For PCollections:
+      // BOUNDED behaves as the Identity Element, BOUNDED + BOUNDED = BOUNDED
+      // while BOUNDED + UNBOUNDED = UNBOUNDED.
+      PCollection.IsBounded isBounded = PCollection.IsBounded.BOUNDED;
+      for (PValue pValue: pValues) {
+        if (pValue instanceof PCollection) {
+          isBounded = isBounded.and(((PCollection) pValue).isBounded());
+        } else {
+          isBounded = isBounded.and(PCollection.IsBounded.BOUNDED);
+        }
+      }
+      return isBounded;
+    }
   }
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0feb6499/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/AccumulatorSingleton.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/AccumulatorSingleton.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/AccumulatorSingleton.java
new file mode 100644
index 0000000..758372e
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/AccumulatorSingleton.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.spark.aggregators;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.apache.spark.Accumulator;
+import org.apache.spark.api.java.JavaSparkContext;
+
+/**
+ * For resilience, {@link Accumulator}s are required to be wrapped in a Singleton.
+ * @see <a href="https://spark.apache.org/docs/1.6.2/streaming-programming-guide.html#accumulators-and-broadcast-variables">accumulators</a>
+ */
+public class AccumulatorSingleton {
+
+  private static volatile Accumulator<NamedAggregators> instance = null;
+
+  public static Accumulator<NamedAggregators> getInstance(JavaSparkContext jsc) {
+    if (instance == null) {
+      synchronized (AccumulatorSingleton.class) {
+        if (instance == null) {
+          //TODO: currently when recovering from checkpoint, Spark does not recover the
+          // last known Accumulator value. The SparkRunner should be able to persist and recover
+          // the NamedAggregators in order to recover Aggregators as well.
+          instance = jsc.sc().accumulator(new NamedAggregators(), new AggAccumParam());
+        }
+      }
+    }
+    return instance;
+  }
+
+  @VisibleForTesting
+  public static void clear() {
+    synchronized (AccumulatorSingleton.class) {
+      instance = null;
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0feb6499/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java
index 454b760..79639a2 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java
@@ -22,10 +22,12 @@ import java.util.Iterator;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
+import org.apache.beam.runners.spark.aggregators.NamedAggregators;
 import org.apache.beam.runners.spark.util.BroadcastHelper;
 import org.apache.beam.sdk.transforms.OldDoFn;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.TupleTag;
+import org.apache.spark.Accumulator;
 import org.apache.spark.api.java.function.FlatMapFunction;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -37,8 +39,8 @@ import org.slf4j.LoggerFactory;
  * @param <OutputT> Output element type.
  */
 public class DoFnFunction<InputT, OutputT>
-    implements FlatMapFunction<Iterator<WindowedValue<InputT>>,
-    WindowedValue<OutputT>> {
+    implements FlatMapFunction<Iterator<WindowedValue<InputT>>, WindowedValue<OutputT>> {
+  private final Accumulator<NamedAggregators> accum;
   private final OldDoFn<InputT, OutputT> mFunction;
   private static final Logger LOG = LoggerFactory.getLogger(DoFnFunction.class);
 
@@ -46,18 +48,32 @@ public class DoFnFunction<InputT, OutputT>
   private final Map<TupleTag<?>, BroadcastHelper<?>> mSideInputs;
 
   /**
+   * @param accum      The Spark Accumulator that handles the Beam Aggregators.
    * @param fn         DoFunction to be wrapped.
    * @param runtime    Runtime to apply function in.
    * @param sideInputs Side inputs used in DoFunction.
    */
-  public DoFnFunction(OldDoFn<InputT, OutputT> fn,
-               SparkRuntimeContext runtime,
-               Map<TupleTag<?>, BroadcastHelper<?>> sideInputs) {
+  public DoFnFunction(Accumulator<NamedAggregators> accum,
+                      OldDoFn<InputT, OutputT> fn,
+                      SparkRuntimeContext runtime,
+                      Map<TupleTag<?>, BroadcastHelper<?>> sideInputs) {
+    this.accum = accum;
     this.mFunction = fn;
     this.mRuntimeContext = runtime;
     this.mSideInputs = sideInputs;
   }
 
+  /**
+   * @param fn         DoFunction to be wrapped.
+   * @param runtime    Runtime to apply function in.
+   * @param sideInputs Side inputs used in DoFunction.
+   */
+  public DoFnFunction(OldDoFn<InputT, OutputT> fn,
+                      SparkRuntimeContext runtime,
+                      Map<TupleTag<?>, BroadcastHelper<?>> sideInputs) {
+    this(null, fn, runtime, sideInputs);
+  }
+
   @Override
   public Iterable<WindowedValue<OutputT>> call(Iterator<WindowedValue<InputT>> iter) throws
       Exception {
@@ -103,6 +119,15 @@ public class DoFnFunction<InputT, OutputT>
     }
 
     @Override
+    public Accumulator<NamedAggregators> getAccumulator() {
+      if (accum == null) {
+        throw new UnsupportedOperationException("SparkRunner does not provide Aggregator support "
+             + "for DoFnFunction of type: " + mFunction.getClass().getCanonicalName());
+      }
+      return accum;
+    }
+
+    @Override
     protected void clearOutput() {
       outputs.clear();
     }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0feb6499/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java
index 724f54f..2397276 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java
@@ -29,6 +29,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import org.apache.beam.runners.spark.EvaluationResult;
+import org.apache.beam.runners.spark.aggregators.AccumulatorSingleton;
 import org.apache.beam.runners.spark.coders.CoderHelpers;
 import org.apache.beam.sdk.AggregatorRetrievalException;
 import org.apache.beam.sdk.AggregatorValues;
@@ -68,7 +69,7 @@ public class EvaluationContext implements EvaluationResult {
   public EvaluationContext(JavaSparkContext jsc, Pipeline pipeline) {
     this.jsc = jsc;
     this.pipeline = pipeline;
-    this.runtime = new SparkRuntimeContext(jsc, pipeline);
+    this.runtime = new SparkRuntimeContext(pipeline, jsc);
   }
 
   /**
@@ -136,7 +137,7 @@ public class EvaluationContext implements EvaluationResult {
     return jsc;
   }
 
-  protected Pipeline getPipeline() {
+  public Pipeline getPipeline() {
     return pipeline;
   }
 
@@ -144,7 +145,7 @@ public class EvaluationContext implements EvaluationResult {
     return runtime;
   }
 
-  protected void setCurrentTransform(AppliedPTransform<?, ?, ?> transform) {
+  public void setCurrentTransform(AppliedPTransform<?, ?, ?> transform) {
     this.currentTransform = transform;
   }
 
@@ -178,7 +179,7 @@ public class EvaluationContext implements EvaluationResult {
     pcollections.put((PValue) getOutput(transform), new RDDHolder<>(values, coder));
   }
 
-  void setPView(PValue view, Iterable<? extends WindowedValue<?>> value) {
+  public void setPView(PValue view, Iterable<? extends WindowedValue<?>> value) {
     pview.put(view, value);
   }
 
@@ -187,7 +188,7 @@ public class EvaluationContext implements EvaluationResult {
     return pcollections.containsKey(pvalue);
   }
 
-  protected JavaRDDLike<?, ?> getRDD(PValue pvalue) {
+  public JavaRDDLike<?, ?> getRDD(PValue pvalue) {
     RDDHolder<?> rddHolder = pcollections.get(pvalue);
     JavaRDDLike<?, ?> rdd = rddHolder.getRDD();
     leafRdds.remove(rddHolder);
@@ -211,7 +212,7 @@ public class EvaluationContext implements EvaluationResult {
     leafRdds.add(rddHolder);
   }
 
-  JavaRDDLike<?, ?> getInputRDD(PTransform<? extends PInput, ?> transform) {
+  protected JavaRDDLike<?, ?> getInputRDD(PTransform<? extends PInput, ?> transform) {
     return getRDD((PValue) getInput(transform));
   }
 
@@ -252,13 +253,13 @@ public class EvaluationContext implements EvaluationResult {
 
   @Override
   public <T> T getAggregatorValue(String named, Class<T> resultType) {
-    return runtime.getAggregatorValue(named, resultType);
+    return runtime.getAggregatorValue(AccumulatorSingleton.getInstance(jsc), named, resultType);
   }
 
   @Override
   public <T> AggregatorValues<T> getAggregatorValues(Aggregator<?, T> aggregator)
       throws AggregatorRetrievalException {
-    return runtime.getAggregatorValues(aggregator);
+    return runtime.getAggregatorValues(AccumulatorSingleton.getInstance(jsc), aggregator);
   }
 
   @Override

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0feb6499/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java
new file mode 100644
index 0000000..eb4002e
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java
@@ -0,0 +1,262 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.spark.translation;
+
+
+import com.google.common.collect.Lists;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import org.apache.beam.runners.core.GroupAlsoByWindowsViaOutputBufferDoFn;
+import org.apache.beam.runners.core.GroupByKeyViaGroupByKeyOnly;
+import org.apache.beam.runners.core.SystemReduceFn;
+import org.apache.beam.runners.spark.aggregators.NamedAggregators;
+import org.apache.beam.runners.spark.coders.CoderHelpers;
+import org.apache.beam.runners.spark.util.ByteArray;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.IterableCoder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.transforms.OldDoFn;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.util.WindowingStrategy;
+import org.apache.beam.sdk.values.KV;
+import org.apache.spark.Accumulator;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.function.Function2;
+import org.apache.spark.api.java.function.PairFlatMapFunction;
+
+import scala.Tuple2;
+
+/**
+ * A set of group/combine functions to apply to Spark {@link org.apache.spark.rdd.RDD}s.
+ */
+public class GroupCombineFunctions {
+
+  /**
+   * Apply {@link GroupByKeyViaGroupByKeyOnly.GroupByKeyOnly} to a Spark RDD.
+   */
+  public static <K, V> JavaRDD<WindowedValue<KV<K, Iterable<V>>>> groupByKeyOnly(
+      JavaRDD<WindowedValue<KV<K, V>>> rdd, KvCoder<K, V> coder) {
+    final Coder<K> keyCoder = coder.getKeyCoder();
+    final Coder<V> valueCoder = coder.getValueCoder();
+    // Use coders to convert objects in the PCollection to byte arrays, so they
+    // can be transferred over the network for the shuffle.
+    return rdd.map(WindowingHelpers.<KV<K, V>>unwindowFunction())
+        .mapToPair(TranslationUtils.<K, V>toPairFunction())
+        .mapToPair(CoderHelpers.toByteFunction(keyCoder, valueCoder))
+        .groupByKey()
+        .mapToPair(CoderHelpers.fromByteFunctionIterable(keyCoder, valueCoder))
+        // empty windows are OK here, see GroupByKey#evaluateHelper in the SDK
+        .map(TranslationUtils.<K, Iterable<V>>fromPairFunction())
+        .map(WindowingHelpers.<KV<K, Iterable<V>>>windowFunction());
+  }
+
+  /**
+   * Apply {@link GroupByKeyViaGroupByKeyOnly.GroupAlsoByWindow} to a Spark RDD.
+   */
+  public static <K, V, W extends BoundedWindow> JavaRDD<WindowedValue<KV<K, Iterable<V>>>>
+  groupAlsoByWindow(JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>> rdd,
+                    GroupByKeyViaGroupByKeyOnly.GroupAlsoByWindow<K, V> transform,
+                    SparkRuntimeContext runtimeContext,
+                    Accumulator<NamedAggregators> accum,
+                    KvCoder<K, Iterable<WindowedValue<V>>> inputKvCoder) {
+    //--- coders.
+    Coder<Iterable<WindowedValue<V>>> inputValueCoder = inputKvCoder.getValueCoder();
+    IterableCoder<WindowedValue<V>> inputIterableValueCoder =
+        (IterableCoder<WindowedValue<V>>) inputValueCoder;
+    Coder<WindowedValue<V>> inputIterableElementCoder = inputIterableValueCoder.getElemCoder();
+    WindowedValue.WindowedValueCoder<V> inputIterableWindowedValueCoder =
+        (WindowedValue.WindowedValueCoder<V>) inputIterableElementCoder;
+    Coder<V> inputIterableElementValueCoder = inputIterableWindowedValueCoder.getValueCoder();
+
+    @SuppressWarnings("unchecked")
+    WindowingStrategy<?, W> windowingStrategy =
+        (WindowingStrategy<?, W>) transform.getWindowingStrategy();
+
+    // GroupAlsoByWindow currently uses a dummy in-memory StateInternals
+    OldDoFn<KV<K, Iterable<WindowedValue<V>>>, KV<K, Iterable<V>>> gabwDoFn =
+        new GroupAlsoByWindowsViaOutputBufferDoFn<K, V, Iterable<V>, W>(
+            windowingStrategy, new TranslationUtils.InMemoryStateInternalsFactory<K>(),
+                SystemReduceFn.<K, V, W>buffering(inputIterableElementValueCoder));
+    return rdd.mapPartitions(new DoFnFunction<>(accum, gabwDoFn, runtimeContext, null));
+  }
+
+  /**
+   * Apply a composite {@link org.apache.beam.sdk.transforms.Combine.Globally} transformation.
+   */
+  public static <InputT, AccumT, OutputT> OutputT
+  combineGlobally(JavaRDD<WindowedValue<InputT>> rdd,
+                  final Combine.CombineFn<InputT, AccumT, OutputT> globally,
+                  final Coder<InputT> iCoder,
+                  final Coder<AccumT> aCoder) {
+    // Use coders to convert objects in the PCollection to byte arrays, so they
+    // can be transferred over the network for the shuffle.
+    JavaRDD<byte[]> inRddBytes = rdd.map(WindowingHelpers.<InputT>unwindowFunction()).map(
+        CoderHelpers.toByteFunction(iCoder));
+    /*AccumT*/ byte[] acc = inRddBytes.aggregate(
+        CoderHelpers.toByteArray(globally.createAccumulator(), aCoder),
+        new Function2</*AccumT*/ byte[], /*InputT*/ byte[], /*AccumT*/ byte[]>() {
+          @Override
+          public /*AccumT*/ byte[] call(/*AccumT*/ byte[] ab, /*InputT*/ byte[] ib)
+              throws Exception {
+            AccumT a = CoderHelpers.fromByteArray(ab, aCoder);
+            InputT i = CoderHelpers.fromByteArray(ib, iCoder);
+            return CoderHelpers.toByteArray(globally.addInput(a, i), aCoder);
+          }
+        },
+        new Function2</*AccumT*/ byte[], /*AccumT*/ byte[], /*AccumT*/ byte[]>() {
+          @Override
+          public /*AccumT*/ byte[] call(/*AccumT*/ byte[] a1b, /*AccumT*/ byte[] a2b)
+              throws Exception {
+            AccumT a1 = CoderHelpers.fromByteArray(a1b, aCoder);
+            AccumT a2 = CoderHelpers.fromByteArray(a2b, aCoder);
+            // don't use Guava's ImmutableList.of as values may be null
+            List<AccumT> accumulators = Collections.unmodifiableList(Arrays.asList(a1, a2));
+            AccumT merged = globally.mergeAccumulators(accumulators);
+            return CoderHelpers.toByteArray(merged, aCoder);
+          }
+        }
+    );
+    return globally.extractOutput(CoderHelpers.fromByteArray(acc, aCoder));
+  }
+
+  /**
+   * Apply a composite {@link org.apache.beam.sdk.transforms.Combine.PerKey} transformation.
+   */
+  public static <K, InputT, AccumT, OutputT> JavaRDD<WindowedValue<KV<K, OutputT>>>
+  combinePerKey(JavaRDD<WindowedValue<KV<K, InputT>>> rdd,
+                final Combine.KeyedCombineFn<K, InputT, AccumT, OutputT> keyed,
+                final WindowedValue.FullWindowedValueCoder<K> wkCoder,
+                final WindowedValue.FullWindowedValueCoder<KV<K, InputT>> wkviCoder,
+                final WindowedValue.FullWindowedValueCoder<KV<K, AccumT>> wkvaCoder) {
+    // We need to duplicate K as both the key of the JavaPairRDD as well as inside the value,
+    // since the functions passed to combineByKey don't receive the associated key of each
+    // value, and we need to map back into methods in Combine.KeyedCombineFn, which each
+    // require the key in addition to the InputT's and AccumT's being merged/accumulated.
+    // Once Spark provides a way to include keys in the arguments of combine/merge functions,
+    // we won't need to duplicate the keys anymore.
+    // Key has to be windowed in order to group by window as well.
+    JavaPairRDD<WindowedValue<K>, WindowedValue<KV<K, InputT>>> inRddDuplicatedKeyPair =
+        rdd.flatMapToPair(
+            new PairFlatMapFunction<WindowedValue<KV<K, InputT>>, WindowedValue<K>,
+                WindowedValue<KV<K, InputT>>>() {
+              @Override
+              public Iterable<Tuple2<WindowedValue<K>, WindowedValue<KV<K, InputT>>>>
+              call(WindowedValue<KV<K, InputT>> kv) {
+                  List<Tuple2<WindowedValue<K>,
+                      WindowedValue<KV<K, InputT>>>> tuple2s =
+                      Lists.newArrayListWithCapacity(kv.getWindows().size());
+                  for (BoundedWindow boundedWindow: kv.getWindows()) {
+                    WindowedValue<K> wk = WindowedValue.of(kv.getValue().getKey(),
+                        boundedWindow.maxTimestamp(), boundedWindow, kv.getPane());
+                    tuple2s.add(new Tuple2<>(wk, kv));
+                  }
+                return tuple2s;
+              }
+            });
+    // Use coders to convert objects in the PCollection to byte arrays, so they
+    // can be transferred over the network for the shuffle.
+    JavaPairRDD<ByteArray, byte[]> inRddDuplicatedKeyPairBytes = inRddDuplicatedKeyPair
+        .mapToPair(CoderHelpers.toByteFunction(wkCoder, wkviCoder));
+
+    // The output of combineByKey will be "AccumT" (accumulator)
+    // types rather than "OutputT" (final output types) since Combine.CombineFn
+    // only provides ways to merge VAs, and no way to merge VOs.
+    JavaPairRDD</*K*/ ByteArray, /*KV<K, AccumT>*/ byte[]> accumulatedBytes =
+        inRddDuplicatedKeyPairBytes.combineByKey(
+        new Function</*KV<K, InputT>*/ byte[], /*KV<K, AccumT>*/ byte[]>() {
+          @Override
+          public /*KV<K, AccumT>*/ byte[] call(/*KV<K, InputT>*/ byte[] input) {
+            WindowedValue<KV<K, InputT>> wkvi =
+                CoderHelpers.fromByteArray(input, wkviCoder);
+            AccumT va = keyed.createAccumulator(wkvi.getValue().getKey());
+            va = keyed.addInput(wkvi.getValue().getKey(), va, wkvi.getValue().getValue());
+            WindowedValue<KV<K, AccumT>> wkva =
+                WindowedValue.of(KV.of(wkvi.getValue().getKey(), va), wkvi.getTimestamp(),
+                wkvi.getWindows(), wkvi.getPane());
+            return CoderHelpers.toByteArray(wkva, wkvaCoder);
+          }
+        },
+        new Function2</*KV<K, AccumT>*/ byte[],
+            /*KV<K, InputT>*/ byte[],
+            /*KV<K, AccumT>*/ byte[]>() {
+          @Override
+          public /*KV<K, AccumT>*/ byte[] call(/*KV<K, AccumT>*/ byte[] acc,
+              /*KV<K, InputT>*/ byte[] input) {
+            WindowedValue<KV<K, AccumT>> wkva =
+                CoderHelpers.fromByteArray(acc, wkvaCoder);
+            WindowedValue<KV<K, InputT>> wkvi =
+                CoderHelpers.fromByteArray(input, wkviCoder);
+            AccumT va =
+                keyed.addInput(wkva.getValue().getKey(), wkva.getValue().getValue(),
+                wkvi.getValue().getValue());
+            wkva = WindowedValue.of(KV.of(wkva.getValue().getKey(), va), wkva.getTimestamp(),
+                wkva.getWindows(), wkva.getPane());
+            return CoderHelpers.toByteArray(wkva, wkvaCoder);
+          }
+        },
+        new Function2</*KV<K, AccumT>*/ byte[],
+            /*KV<K, AccumT>*/ byte[],
+            /*KV<K, AccumT>*/ byte[]>() {
+          @Override
+          public /*KV<K, AccumT>*/ byte[] call(/*KV<K, AccumT>*/ byte[] acc1,
+              /*KV<K, AccumT>*/ byte[] acc2) {
+            WindowedValue<KV<K, AccumT>> wkva1 =
+                CoderHelpers.fromByteArray(acc1, wkvaCoder);
+            WindowedValue<KV<K, AccumT>> wkva2 =
+                CoderHelpers.fromByteArray(acc2, wkvaCoder);
+            AccumT va = keyed.mergeAccumulators(wkva1.getValue().getKey(),
+                // don't use Guava's ImmutableList.of as values may be null
+                Collections.unmodifiableList(Arrays.asList(wkva1.getValue().getValue(),
+                wkva2.getValue().getValue())));
+            WindowedValue<KV<K, AccumT>> wkva =
+                WindowedValue.of(KV.of(wkva1.getValue().getKey(),
+                va), wkva1.getTimestamp(), wkva1.getWindows(), wkva1.getPane());
+            return CoderHelpers.toByteArray(wkva, wkvaCoder);
+          }
+        });
+
+    JavaPairRDD<WindowedValue<K>, WindowedValue<OutputT>> extracted = accumulatedBytes
+        .mapToPair(CoderHelpers.fromByteFunction(wkCoder, wkvaCoder))
+        .mapValues(new Function<WindowedValue<KV<K, AccumT>>, WindowedValue<OutputT>>() {
+              @Override
+              public WindowedValue<OutputT> call(WindowedValue<KV<K, AccumT>> acc) {
+                return WindowedValue.of(keyed.extractOutput(acc.getValue().getKey(),
+                    acc.getValue().getValue()), acc.getTimestamp(), acc.getWindows(),
+                        acc.getPane());
+              }
+            });
+    return extracted.map(TranslationUtils.<WindowedValue<K>,
+        WindowedValue<OutputT>>fromPairFunction()).map(
+            new Function<KV<WindowedValue<K>, WindowedValue<OutputT>>,
+                WindowedValue<KV<K, OutputT>>>() {
+              @Override
+              public WindowedValue<KV<K, OutputT>> call(KV<WindowedValue<K>,
+                  WindowedValue<OutputT>> kwvo) throws Exception {
+                WindowedValue<OutputT> wvo = kwvo.getValue();
+                KV<K, OutputT> kvo = KV.of(kwvo.getKey().getValue(), wvo.getValue());
+                return WindowedValue.of(kvo, wvo.getTimestamp(), wvo.getWindows(), wvo.getPane());
+              }
+            });
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0feb6499/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
index 4c44ffd..163cf13 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
@@ -24,12 +24,15 @@ import com.google.common.collect.LinkedListMultimap;
 import com.google.common.collect.Multimap;
 import java.util.Iterator;
 import java.util.Map;
+import org.apache.beam.runners.spark.aggregators.NamedAggregators;
 import org.apache.beam.runners.spark.util.BroadcastHelper;
 import org.apache.beam.sdk.transforms.OldDoFn;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.TupleTag;
+import org.apache.spark.Accumulator;
 import org.apache.spark.api.java.function.PairFlatMapFunction;
 import org.joda.time.Instant;
+
 import scala.Tuple2;
 
 /**
@@ -39,24 +42,50 @@ import scala.Tuple2;
  * @param <InputT> Input type for DoFunction.
  * @param <OutputT> Output type for DoFunction.
  */
-class MultiDoFnFunction<InputT, OutputT>
-    implements PairFlatMapFunction<Iterator<WindowedValue<InputT>>, TupleTag<?>, WindowedValue<?>> {
+public class MultiDoFnFunction<InputT, OutputT>
+    implements PairFlatMapFunction<Iterator<WindowedValue<InputT>>, TupleTag<?>,
+        WindowedValue<?>> {
+  private final Accumulator<NamedAggregators> accum;
   private final OldDoFn<InputT, OutputT> mFunction;
   private final SparkRuntimeContext mRuntimeContext;
   private final TupleTag<OutputT> mMainOutputTag;
   private final Map<TupleTag<?>, BroadcastHelper<?>> mSideInputs;
 
-  MultiDoFnFunction(
+  /**
+   * @param accum          The Spark Accumulator that handles the Beam Aggregators.
+   * @param fn             DoFunction to be wrapped.
+   * @param runtimeContext Runtime to apply function in.
+   * @param mainOutputTag  The main output {@link TupleTag}.
+   * @param sideInputs     Side inputs used in DoFunction.
+   */
+  public MultiDoFnFunction(
+      Accumulator<NamedAggregators> accum,
       OldDoFn<InputT, OutputT> fn,
       SparkRuntimeContext runtimeContext,
       TupleTag<OutputT> mainOutputTag,
       Map<TupleTag<?>, BroadcastHelper<?>> sideInputs) {
+    this.accum = accum;
     this.mFunction = fn;
     this.mRuntimeContext = runtimeContext;
     this.mMainOutputTag = mainOutputTag;
     this.mSideInputs = sideInputs;
   }
 
+  /**
+   * @param fn             DoFunction to be wrapped.
+   * @param runtimeContext Runtime to apply function in.
+   * @param mainOutputTag  The main output {@link TupleTag}.
+   * @param sideInputs     Side inputs used in DoFunction.
+   */
+  public MultiDoFnFunction(
+      OldDoFn<InputT, OutputT> fn,
+      SparkRuntimeContext runtimeContext,
+      TupleTag<OutputT> mainOutputTag,
+      Map<TupleTag<?>, BroadcastHelper<?>> sideInputs) {
+    this(null, fn, runtimeContext, mainOutputTag, sideInputs);
+  }
+
+
   @Override
   public Iterable<Tuple2<TupleTag<?>, WindowedValue<?>>>
       call(Iterator<WindowedValue<InputT>> iter) throws Exception {
@@ -99,6 +128,15 @@ class MultiDoFnFunction<InputT, OutputT>
     }
 
     @Override
+    public Accumulator<NamedAggregators> getAccumulator() {
+      if (accum == null) {
+        throw new UnsupportedOperationException("SparkRunner does not provide Aggregator support "
+             + "for MultiDoFnFunction of type: " + mFunction.getClass().getCanonicalName());
+      }
+      return accum;
+    }
+
+    @Override
     protected void clearOutput() {
       outputs.clear();
     }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0feb6499/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkContextFactory.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkContextFactory.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkContextFactory.java
index 0e7db9f..8127ddc 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkContextFactory.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkContextFactory.java
@@ -18,14 +18,18 @@
 
 package org.apache.beam.runners.spark.translation;
 
+import org.apache.beam.runners.spark.SparkPipelineOptions;
 import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.serializer.KryoSerializer;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 /**
  * The Spark context factory.
  */
 public final class SparkContextFactory {
+  private static final Logger LOG = LoggerFactory.getLogger(SparkContextFactory.class);
 
   /**
    * If the property {@code beam.spark.test.reuseSparkContext} is set to
@@ -40,19 +44,20 @@ public final class SparkContextFactory {
   private SparkContextFactory() {
   }
 
-  public static synchronized JavaSparkContext getSparkContext(String master, String appName) {
-    if (Boolean.getBoolean(TEST_REUSE_SPARK_CONTEXT)) {
+  public static synchronized JavaSparkContext getSparkContext(SparkPipelineOptions options) {
+    // reuse should be ignored if the context is provided.
+    if (Boolean.getBoolean(TEST_REUSE_SPARK_CONTEXT) && !options.getUsesProvidedSparkContext()) {
       if (sparkContext == null) {
-        sparkContext = createSparkContext(master, appName);
-        sparkMaster = master;
-      } else if (!master.equals(sparkMaster)) {
+        sparkContext = createSparkContext(options);
+        sparkMaster = options.getSparkMaster();
+      } else if (!options.getSparkMaster().equals(sparkMaster)) {
         throw new IllegalArgumentException(String.format("Cannot reuse spark context "
-                + "with different spark master URL. Existing: %s, requested: %s.",
-            sparkMaster, master));
+            + "with different spark master URL. Existing: %s, requested: %s.",
+                sparkMaster, options.getSparkMaster()));
       }
       return sparkContext;
     } else {
-      return createSparkContext(master, appName);
+      return createSparkContext(options);
     }
   }
 
@@ -62,14 +67,25 @@ public final class SparkContextFactory {
     }
   }
 
-  private static JavaSparkContext createSparkContext(String master, String appName) {
-    SparkConf conf = new SparkConf();
-    if (!conf.contains("spark.master")) {
-      // set master if not set.
-      conf.setMaster(master);
+  private static JavaSparkContext createSparkContext(SparkPipelineOptions options) {
+    if (options.getUsesProvidedSparkContext()) {
+      LOG.info("Using a provided Spark Context");
+      JavaSparkContext jsc = options.getProvidedSparkContext();
+      if (jsc == null || jsc.sc().isStopped()){
+        LOG.error("The provided Spark context " + jsc + " was not created or was stopped");
+        throw new RuntimeException("The provided Spark context was not created or was stopped");
+      }
+      return jsc;
+    } else {
+      LOG.info("Creating a brand new Spark Context.");
+      SparkConf conf = new SparkConf();
+      if (!conf.contains("spark.master")) {
+        // set master if not set.
+        conf.setMaster(options.getSparkMaster());
+      }
+      conf.setAppName(options.getAppName());
+      conf.set("spark.serializer", KryoSerializer.class.getCanonicalName());
+      return new JavaSparkContext(conf);
     }
-    conf.setAppName(appName);
-    conf.set("spark.serializer", KryoSerializer.class.getCanonicalName());
-    return new JavaSparkContext(conf);
   }
 }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0feb6499/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPipelineEvaluator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPipelineEvaluator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPipelineEvaluator.java
deleted file mode 100644
index 02e8b3d..0000000
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPipelineEvaluator.java
+++ /dev/null
@@ -1,57 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.beam.runners.spark.translation;
-
-import org.apache.beam.runners.spark.SparkRunner;
-import org.apache.beam.sdk.runners.TransformTreeNode;
-import org.apache.beam.sdk.transforms.AppliedPTransform;
-import org.apache.beam.sdk.transforms.PTransform;
-import org.apache.beam.sdk.values.PInput;
-import org.apache.beam.sdk.values.POutput;
-
-/**
- * Pipeline {@link SparkRunner.Evaluator} for Spark.
- */
-public final class SparkPipelineEvaluator extends SparkRunner.Evaluator {
-
-  private final EvaluationContext ctxt;
-
-  public SparkPipelineEvaluator(EvaluationContext ctxt, SparkPipelineTranslator translator) {
-    super(translator);
-    this.ctxt = ctxt;
-  }
-
-  @Override
-  protected <TransformT extends PTransform<? super PInput, POutput>>
-  void doVisitTransform(TransformTreeNode
-      node) {
-    @SuppressWarnings("unchecked")
-    TransformT transform = (TransformT) node.getTransform();
-    @SuppressWarnings("unchecked")
-    Class<TransformT> transformClass = (Class<TransformT>) (Class<?>) transform.getClass();
-    @SuppressWarnings("unchecked") TransformEvaluator<TransformT> evaluator =
-        (TransformEvaluator<TransformT>) translator.translate(transformClass);
-    LOG.info("Evaluating {}", transform);
-    AppliedPTransform<PInput, POutput, TransformT> appliedTransform =
-        AppliedPTransform.of(node.getFullName(), node.getInput(), node.getOutput(), transform);
-    ctxt.setCurrentTransform(appliedTransform);
-    evaluator.evaluate(transform, ctxt);
-    ctxt.setCurrentTransform(null);
-  }
-}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0feb6499/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPipelineTranslator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPipelineTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPipelineTranslator.java
index 1f7ccf1..f77df5f 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPipelineTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPipelineTranslator.java
@@ -27,5 +27,8 @@ public interface SparkPipelineTranslator {
   boolean hasTranslation(Class<? extends PTransform<?, ?>> clazz);
 
   <TransformT extends PTransform<?, ?>> TransformEvaluator<TransformT>
-  translate(Class<TransformT> clazz);
+  translateBounded(Class<TransformT> clazz);
+
+  <TransformT extends PTransform<?, ?>> TransformEvaluator<TransformT>
+  translateUnbounded(Class<TransformT> clazz);
 }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0feb6499/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java
index 566a272..fbaf5b8 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java
@@ -24,6 +24,7 @@ import java.io.IOException;
 import java.util.Collection;
 import java.util.Iterator;
 import java.util.Map;
+import org.apache.beam.runners.spark.aggregators.NamedAggregators;
 import org.apache.beam.runners.spark.util.BroadcastHelper;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.options.PipelineOptions;
@@ -39,6 +40,7 @@ import org.apache.beam.sdk.util.state.InMemoryStateInternals;
 import org.apache.beam.sdk.util.state.StateInternals;
 import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.TupleTag;
+import org.apache.spark.Accumulator;
 import org.joda.time.Instant;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -109,11 +111,13 @@ public abstract class SparkProcessContext<InputT, OutputT, ValueT>
   @Override
   public <AggregatprInputT, AggregatorOutputT>
   Aggregator<AggregatprInputT, AggregatorOutputT> createAggregatorInternal(
-      String named,
-      Combine.CombineFn<AggregatprInputT, ?, AggregatorOutputT> combineFn) {
-    return mRuntimeContext.createAggregator(named, combineFn);
+          String named,
+          Combine.CombineFn<AggregatprInputT, ?, AggregatorOutputT> combineFn) {
+    return mRuntimeContext.createAggregator(getAccumulator(), named, combineFn);
   }
 
+  public abstract Accumulator<NamedAggregators> getAccumulator();
+
   @Override
   public InputT element() {
     return windowedValue.getValue();

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0feb6499/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkRuntimeContext.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkRuntimeContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkRuntimeContext.java
index 4e4cd1a..94c1648 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkRuntimeContext.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkRuntimeContext.java
@@ -20,17 +20,14 @@ package org.apache.beam.runners.spark.translation;
 
 import com.fasterxml.jackson.core.JsonProcessingException;
 import com.fasterxml.jackson.databind.ObjectMapper;
-
 import com.google.common.collect.ImmutableList;
-
 import java.io.IOException;
 import java.io.Serializable;
 import java.util.Collection;
 import java.util.HashMap;
 import java.util.Map;
-
 import org.apache.beam.runners.spark.SparkPipelineOptions;
-import org.apache.beam.runners.spark.aggregators.AggAccumParam;
+import org.apache.beam.runners.spark.aggregators.AccumulatorSingleton;
 import org.apache.beam.runners.spark.aggregators.NamedAggregators;
 import org.apache.beam.runners.spark.aggregators.metrics.AggregatorMetricSource;
 import org.apache.beam.sdk.AggregatorValues;
@@ -56,11 +53,6 @@ import org.apache.spark.metrics.MetricsSystem;
  * data flow program is launched.
  */
 public class SparkRuntimeContext implements Serializable {
-  /**
-   * An accumulator that is a map from names to aggregators.
-   */
-  private final Accumulator<NamedAggregators> accum;
-
   private final String serializedPipelineOptions;
 
   /**
@@ -69,10 +61,9 @@ public class SparkRuntimeContext implements Serializable {
   private final Map<String, Aggregator<?, ?>> aggregators = new HashMap<>();
   private transient CoderRegistry coderRegistry;
 
-  SparkRuntimeContext(JavaSparkContext jsc, Pipeline pipeline) {
-    final SparkPipelineOptions opts = pipeline.getOptions().as(SparkPipelineOptions.class);
-    accum = registerMetrics(jsc, opts);
-    serializedPipelineOptions = serializePipelineOptions(opts);
+  SparkRuntimeContext(Pipeline pipeline, JavaSparkContext jsc) {
+    this.serializedPipelineOptions = serializePipelineOptions(pipeline.getOptions());
+    registerMetrics(pipeline.getOptions().as(SparkPipelineOptions.class), jsc);
   }
 
   private static String serializePipelineOptions(PipelineOptions pipelineOptions) {
@@ -91,10 +82,9 @@ public class SparkRuntimeContext implements Serializable {
     }
   }
 
-  private Accumulator<NamedAggregators> registerMetrics(final JavaSparkContext jsc,
-                                                        final SparkPipelineOptions opts) {
-    final NamedAggregators initialValue = new NamedAggregators();
-    final Accumulator<NamedAggregators> accum = jsc.accumulator(initialValue, new AggAccumParam());
+  private void registerMetrics(final SparkPipelineOptions opts, final JavaSparkContext jsc) {
+    final Accumulator<NamedAggregators> accum = AccumulatorSingleton.getInstance(jsc);
+    final NamedAggregators initialValue = accum.value();
 
     if (opts.getEnableSparkSinks()) {
       final MetricsSystem metricsSystem = SparkEnv$.MODULE$.get().metricsSystem();
@@ -104,26 +94,28 @@ public class SparkRuntimeContext implements Serializable {
       metricsSystem.removeSource(aggregatorMetricSource);
       metricsSystem.registerSource(aggregatorMetricSource);
     }
-
-    return accum;
   }
 
   /**
    * Retrieves corresponding value of an aggregator.
    *
+   * @param accum          The Spark Accumulator holding all Aggregators.
    * @param aggregatorName Name of the aggregator to retrieve the value of.
    * @param typeClass      Type class of value to be retrieved.
    * @param <T>            Type of object to be returned.
    * @return The value of the aggregator.
    */
-  public <T> T getAggregatorValue(String aggregatorName, Class<T> typeClass) {
+  public <T> T getAggregatorValue(Accumulator<NamedAggregators> accum,
+                                  String aggregatorName,
+                                  Class<T> typeClass) {
     return accum.value().getValue(aggregatorName, typeClass);
   }
 
-  public <T> AggregatorValues<T> getAggregatorValues(Aggregator<?, T> aggregator) {
+  public <T> AggregatorValues<T> getAggregatorValues(Accumulator<NamedAggregators> accum,
+                                                     Aggregator<?, T> aggregator) {
     @SuppressWarnings("unchecked")
     Class<T> aggValueClass = (Class<T>) aggregator.getCombineFn().getOutputType().getRawType();
-    final T aggregatorValue = getAggregatorValue(aggregator.getName(), aggValueClass);
+    final T aggregatorValue = getAggregatorValue(accum, aggregator.getName(), aggValueClass);
     return new AggregatorValues<T>() {
       @Override
       public Collection<T> getValues() {
@@ -144,14 +136,16 @@ public class SparkRuntimeContext implements Serializable {
   /**
   * Creates an aggregator and associates it with the specified name.
    *
+   * @param accum     Spark Accumulator.
    * @param named     Name of aggregator.
    * @param combineFn Combine function used in aggregation.
-   * @param <InputT>      Type of inputs to aggregator.
-   * @param <InterT>   Intermediate data type
-   * @param <OutputT>     Type of aggregator outputs.
+   * @param <InputT>  Type of inputs to aggregator.
+   * @param <InterT>  Intermediate data type
+   * @param <OutputT> Type of aggregator outputs.
    * @return Specified aggregator
    */
   public synchronized <InputT, InterT, OutputT> Aggregator<InputT, OutputT> createAggregator(
+      Accumulator<NamedAggregators> accum,
       String named,
       Combine.CombineFn<? super InputT, InterT, OutputT> combineFn) {
     @SuppressWarnings("unchecked")


Mime
View raw message