beam-commits mailing list archives

From k...@apache.org
Subject [1/3] incubator-beam git commit: [BEAM-1133] Add maxNumRecords per micro-batch for Spark runner options.
Date Mon, 12 Dec 2016 19:38:34 GMT
Repository: incubator-beam
Updated Branches:
  refs/heads/master 321547fb1 -> bfd21d72f


[BEAM-1133] Add maxNumRecords per micro-batch for Spark runner options.


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/6b14ce53
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/6b14ce53
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/6b14ce53

Branch: refs/heads/master
Commit: 6b14ce538c52d26c3b6a5db3b8b1f603216b21d8
Parents: 321547f
Author: Sela <ansela@paypal.com>
Authored: Mon Dec 12 12:37:32 2016 +0200
Committer: Kenneth Knowles <klk@google.com>
Committed: Mon Dec 12 11:37:15 2016 -0800

----------------------------------------------------------------------
 .../runners/spark/SparkPipelineOptions.java     |  5 +++++
 .../beam/runners/spark/io/SourceDStream.java    | 21 +++++++++++++++-----
 .../runners/spark/io/SparkUnboundedSource.java  | 17 ++++++++++------
 3 files changed, 32 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/6b14ce53/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java
b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java
index 3f8b379..a2cd887 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java
@@ -54,6 +54,11 @@ public interface SparkPipelineOptions
   Long getMinReadTimeMillis();
   void setMinReadTimeMillis(Long minReadTimeMillis);
 
+  @Description("Max records per micro-batch. For streaming sources only.")
+  @Default.Long(-1)
+  Long getMaxRecordsPerBatch();
+  void setMaxRecordsPerBatch(Long maxRecordsPerBatch);
+
   @Description("A value between 0-1 to describe the percentage of a micro-batch dedicated
"
       + "to reading from UnboundedSource.")
   @Default.Double(0.1)
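
For context: PipelineOptionsFactory derives command-line flags from getter/setter pairs, so the property added above should be settable as --maxRecordsPerBatch or programmatically. A minimal, hypothetical usage sketch (not part of this commit; pipeline wiring assumed):

    import org.apache.beam.runners.spark.SparkPipelineOptions;
    import org.apache.beam.sdk.options.PipelineOptionsFactory;

    public class MaxRecordsPerBatchExample {
      public static void main(String[] args) {
        // Parse the flag Beam derives from get/setMaxRecordsPerBatch.
        SparkPipelineOptions options = PipelineOptionsFactory
            .fromArgs("--maxRecordsPerBatch=1000")
            .as(SparkPipelineOptions.class);
        // Equivalent programmatic form; -1 (the default) means no bound.
        options.setMaxRecordsPerBatch(1000L);
      }
    }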

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/6b14ce53/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceDStream.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceDStream.java
b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceDStream.java
index 84b247b..8a0763b 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceDStream.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceDStream.java
@@ -53,7 +53,7 @@ import scala.Tuple2;
  * {@link SparkPipelineOptions#getMinReadTimeMillis()}.
  * Records bound is controlled by the {@link RateController} mechanism.
  */
-public class SourceDStream<T, CheckpointMarkT extends UnboundedSource.CheckpointMark>
+class SourceDStream<T, CheckpointMarkT extends UnboundedSource.CheckpointMark>
       extends InputDStream<Tuple2<Source<T>, CheckpointMarkT>> {
   private static final Logger LOG = LoggerFactory.getLogger(SourceDStream.class);
 
@@ -64,10 +64,16 @@ public class SourceDStream<T, CheckpointMarkT extends UnboundedSource.Checkpoint
   // in case of resuming/recovering from checkpoint, the DStream will be reconstructed and this
   // property should not be reset.
   private final int initialParallelism;
+  // the bound on max records is optional.
+  // if set explicitly via PipelineOptions it takes precedence;
+  // otherwise it may be supplied by the RateController.
+  private Long boundMaxRecords = null;
+
+  SourceDStream(
+      StreamingContext ssc,
+      UnboundedSource<T, CheckpointMarkT> unboundedSource,
+      SparkRuntimeContext runtimeContext) {
 
-  public SourceDStream(StreamingContext ssc,
-                       UnboundedSource<T, CheckpointMarkT> unboundedSource,
-                       SparkRuntimeContext runtimeContext) {
     super(ssc, JavaSparkContext$.MODULE$.<scala.Tuple2<Source<T>, CheckpointMarkT>>fakeClassTag());
     this.unboundedSource = unboundedSource;
     this.runtimeContext = runtimeContext;
@@ -80,10 +86,15 @@ public class SourceDStream<T, CheckpointMarkT extends UnboundedSource.Checkpoint
     checkArgument(this.initialParallelism > 0, "Number of partitions must be greater than zero.");
   }
 
+  public void setMaxRecordsPerBatch(long maxRecordsPerBatch) {
+    boundMaxRecords = maxRecordsPerBatch;
+  }
+
   @Override
   public scala.Option<RDD<Tuple2<Source<T>, CheckpointMarkT>>> compute(Time validTime) {
+    long maxNumRecords = boundMaxRecords != null ? boundMaxRecords : rateControlledMaxRecords();
     MicrobatchSource<T, CheckpointMarkT> microbatchSource = new MicrobatchSource<>(
-        unboundedSource, boundReadDuration, initialParallelism, rateControlledMaxRecords(), -1,
+        unboundedSource, boundReadDuration, initialParallelism, maxNumRecords, -1,
         id());
     RDD<scala.Tuple2<Source<T>, CheckpointMarkT>> rdd = new SourceRDD.Unbounded<>(
         ssc().sc(), runtimeContext, microbatchSource);
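
The precedence implemented in compute() above: a bound set through setMaxRecordsPerBatch(...) wins; otherwise the bound falls back to rateControlledMaxRecords(), i.e. the RateController mechanism named in the class javadoc. Spark's rate control is typically opt-in, so a pipeline relying on that fallback would presumably enable backpressure on its SparkConf; a hedged sketch using standard Spark configuration (not part of this commit):

    import org.apache.spark.SparkConf;

    public class BackpressureConfExample {
      public static void main(String[] args) {
        // With backpressure enabled and no explicit --maxRecordsPerBatch,
        // the RateController estimate bounds each micro-batch instead.
        SparkConf conf = new SparkConf()
            .setAppName("beam-on-spark")
            .set("spark.streaming.backpressure.enabled", "true");
      }
    }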

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/6b14ce53/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java
b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java
index b12098d..394b023 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java
@@ -61,19 +61,25 @@ public class SparkUnboundedSource {
   JavaDStream<WindowedValue<T>> read(JavaStreamingContext jssc,
                                      SparkRuntimeContext rc,
                                      UnboundedSource<T, CheckpointMarkT> source) {
+    SparkPipelineOptions options = rc.getPipelineOptions().as(SparkPipelineOptions.class);
+    Long maxRecordsPerBatch = options.getMaxRecordsPerBatch();
+    SourceDStream<T, CheckpointMarkT> sourceDStream = new SourceDStream<>(jssc.ssc(), source, rc);
+    // if max records per batch was set by the user.
+    if (maxRecordsPerBatch > 0) {
+      sourceDStream.setMaxRecordsPerBatch(maxRecordsPerBatch);
+    }
     JavaPairInputDStream<Source<T>, CheckpointMarkT> inputDStream =
-        JavaPairInputDStream$.MODULE$.fromInputDStream(new SourceDStream<>(jssc.ssc(), source, rc),
+        JavaPairInputDStream$.MODULE$.fromInputDStream(sourceDStream,
             JavaSparkContext$.MODULE$.<Source<T>>fakeClassTag(),
                 JavaSparkContext$.MODULE$.<CheckpointMarkT>fakeClassTag());
 
     // call mapWithState to read from checkpointable sources.
-    //TODO: consider broadcasting the rc instead of re-sending every batch.
     JavaMapWithStateDStream<Source<T>, CheckpointMarkT, byte[],
         Iterator<WindowedValue<T>>> mapWithStateDStream = inputDStream.mapWithState(
             StateSpec.function(StateSpecFunctions.<T, CheckpointMarkT>mapSourceFunction(rc)));
 
     // set checkpoint duration for read stream, if set.
-    checkpointStream(mapWithStateDStream, rc);
+    checkpointStream(mapWithStateDStream, options);
     // flatmap and report read elements. Use the inputDStream's id to tie the reported
     // info to the inputDStream it originated from.
     int id = inputDStream.inputDStream().id();
@@ -97,9 +103,8 @@ public class SparkUnboundedSource {
   }
 
   private static void checkpointStream(JavaDStream<?> dStream,
-                                       SparkRuntimeContext rc) {
-    long checkpointDurationMillis = rc.getPipelineOptions().as(SparkPipelineOptions.class)
-        .getCheckpointDurationMillis();
+                                       SparkPipelineOptions options) {
+    long checkpointDurationMillis = options.getCheckpointDurationMillis();
     if (checkpointDurationMillis > 0) {
       dStream.checkpoint(new Duration(checkpointDurationMillis));
     }
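
The checkpointStream(...) refactor above relies on PipelineOptions being a single backing store viewed through interfaces: rc.getPipelineOptions().as(SparkPipelineOptions.class) returns a Spark-typed view over the same values, so the narrowed options object can be derived once in read(...) and passed along. A small sketch of that view behavior (setter name assumed, mirroring the getter used in the diff):

    import org.apache.beam.runners.spark.SparkPipelineOptions;
    import org.apache.beam.sdk.options.PipelineOptions;
    import org.apache.beam.sdk.options.PipelineOptionsFactory;

    public class OptionsViewExample {
      public static void main(String[] args) {
        PipelineOptions base = PipelineOptionsFactory.create();
        // Re-view the same underlying options as SparkPipelineOptions.
        SparkPipelineOptions spark = base.as(SparkPipelineOptions.class);
        spark.setCheckpointDurationMillis(10000L);
        // Any later view of the same object sees the value.
        System.out.println(
            base.as(SparkPipelineOptions.class).getCheckpointDurationMillis()); // 10000
      }
    }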

