beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From rober...@apache.org
Subject [beam] branch master updated: [BEAM-3837] Handle BundleSplitRequests in Python SDK Harness.
Date Mon, 11 Feb 2019 10:06:35 GMT
This is an automated email from the ASF dual-hosted git repository.

robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new bfc37eb  [BEAM-3837] Handle BundleSplitRequests in Python SDK Harness.
     new f45d674  Merge pull request #7759 [BEAM-3837] Handle BundleSplitRequests in Python
SDK Harness.
bfc37eb is described below

commit bfc37ebd6e6940858b789750d07f4e4b654a6e99
Author: Robert Bradshaw <robertwb@google.com>
AuthorDate: Mon Jan 28 16:07:21 2019 +0100

    [BEAM-3837] Handle BundleSplitRequests in Python SDK Harness.
---
 .../fn-execution/src/main/proto/beam_fn_api.proto  |  57 ++++++++
 sdks/python/apache_beam/io/restriction_trackers.py |  19 ++-
 sdks/python/apache_beam/runners/common.py          |  17 +++
 .../runners/portability/fn_api_runner.py           | 144 ++++++++++++++++++---
 .../apache_beam/runners/worker/bundle_processor.py |  70 +++++++++-
 .../apache_beam/runners/worker/operations.pxd      |   4 +
 .../apache_beam/runners/worker/operations.py       |  56 +++++++-
 .../apache_beam/runners/worker/sdk_worker.py       |  14 ++
 8 files changed, 355 insertions(+), 26 deletions(-)

diff --git a/model/fn-execution/src/main/proto/beam_fn_api.proto b/model/fn-execution/src/main/proto/beam_fn_api.proto
index d681bc1..9a1c86a 100644
--- a/model/fn-execution/src/main/proto/beam_fn_api.proto
+++ b/model/fn-execution/src/main/proto/beam_fn_api.proto
@@ -746,6 +746,31 @@ message ProcessBundleSplitRequest {
   // If the backlog is unspecified for a PTransform, the runner would like
   // the SDK to process all data received for that PTransform.
   map<string, bytes> backlog_remaining = 2;
+
+  // A message specifying the desired split for a single transform.
+  message DesiredSplit {
+    // (Required) The fraction of known work remaining in this bundle
+    // for this transform that should be kept by the SDK after this split.
+    //
+    // Set to 0 to "checkpoint" as soon as possible (keeping as little work as
+    // possible and returning the remainder).
+    float fraction_of_remainder = 1;
+
+    // (Required for GrpcRead operations) Number of total elements expected
+    // to be sent to this GrpcRead operation, required to correctly account
+    // for unreceived data when determining where to split.
+    int64 estimated_input_elements = 2;
+
+    // TODO(SDF): Allow providing weights rather than sizes.
+    // TODO(SDF): Allow specifying allowed/preferred split points.
+  }
+
+  // (Required) Specifies the desired split for each transform.
+  //
+  // Currently only splits at GRPC read operations are supported.
+  // This may, of course, limit the amount of work downstream operations
+  // receive.
+  map<string, DesiredSplit> desired_splits = 3;
 }
 
 // Represents a partition of the bundle: a "primary" and
@@ -765,8 +790,40 @@ message ProcessBundleSplitResponse {
   // have to be executed in a separate bundle (e.g. in parallel on a different
   // worker, or after the current bundle completes, etc.)
   repeated DelayedBundleApplication residual_roots = 2;
+
+  // Represents contiguous portions of the data channel that are either
+  // entirely processed or entirely unprocessed and belong to the primary
+  // or residual respectively.
+  //
+  // This affords both a more efficient representation over the FnAPI
+  // (if the bundle is large) and often a more efficient representation
+  // on the runner side (e.g. if the set of elements can be represented
+  // as some range in an underlying dataset).
+  message ChannelSplit {
+    // (Required) The grpc read transform reading this channel.
+    string ptransform_id = 1;
+
+    // (Required) Name of the transform's input to which to pass the element.
+    string input_id = 2;
+
+    // The last element of the input channel that should be entirely considered
+    // part of the primary, identified by its absolute index in the (ordered)
+    // channel.
+    int32 last_primary_element = 3;
+
+    // The first element of the input channel that should be entirely considered
+    // part of the residual, identified by its absolute index in the (ordered)
+    // channel.
+    int32 first_residual_element = 4;
+  }
+
+  // Partitions of input data channels into primary and residual elements,
+  // if any. Should not include any elements represented in the bundle
+  // application roots above.
+  repeated ChannelSplit channel_splits = 3;
 }
 
+
 message FinalizeBundleRequest {
   // (Required) A reference to a completed process bundle request with the given
   // instruction id.
diff --git a/sdks/python/apache_beam/io/restriction_trackers.py b/sdks/python/apache_beam/io/restriction_trackers.py
index e72d508..c165d95 100644
--- a/sdks/python/apache_beam/io/restriction_trackers.py
+++ b/sdks/python/apache_beam/io/restriction_trackers.py
@@ -79,6 +79,7 @@ class OffsetRestrictionTracker(RestrictionTracker):
   def __init__(self, start_position, stop_position):
     self._range = OffsetRange(start_position, stop_position)
     self._current_position = None
+    self._current_watermark = None
     self._last_claim_attempt = None
     self._deferred_residual = None
     self._checkpointed = False
@@ -98,6 +99,9 @@ class OffsetRestrictionTracker(RestrictionTracker):
     with self._lock:
       return (self._range.start, self._range.stop)
 
+  def current_watermark(self):
+    return self._current_watermark
+
   def start_position(self):
     with self._lock:
       return self._range.start
@@ -127,6 +131,19 @@ class OffsetRestrictionTracker(RestrictionTracker):
 
       return False
 
+  def try_split(self, fraction):
+    with self._lock:
+      if not self._checkpointed:
+        if self._current_position is None:
+          cur = self._range.start - 1
+        else:
+          cur = self._current_position
+        split_point = cur + int(max(1, (self._range.stop - cur) * fraction))
+        if split_point < self._range.stop:
+          prev_stop, self._range.stop = self._range.stop, split_point
+          return (self._range.start, split_point), (split_point, prev_stop)
+
+  # TODO(SDF): Replace all calls with try_split(0).
   def checkpoint(self):
     with self._lock:
       # If self._current_position is 'None' no records have been claimed so
@@ -143,7 +160,7 @@ class OffsetRestrictionTracker(RestrictionTracker):
 
   def defer_remainder(self, watermark=None):
     with self._lock:
-      self._deferred_watermark = watermark
+      self._deferred_watermark = watermark or self._current_watermark
       self._deferred_residual = self.checkpoint()
 
   def deferred_status(self):
diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py
index 3d9b07f..efdb59f 100644
--- a/sdks/python/apache_beam/runners/common.py
+++ b/sdks/python/apache_beam/runners/common.py
@@ -633,6 +633,20 @@ class PerWindowInvoker(DoFnInvoker):
                 (windowed_value.value, deferred_restriction)),
             deferred_watermark)
 
+  def try_split(self, fraction):
+    restriction_tracker = self.restriction_tracker
+    current_windowed_value = self.current_windowed_value
+    if restriction_tracker and current_windowed_value:
+      split = restriction_tracker.try_split(fraction)
+      if split:
+        primary, residual = split
+        element = self.current_windowed_value.value
+        return (
+            (self.current_windowed_value.with_value((element, primary)),
+             None),
+            (self.current_windowed_value.with_value((element, residual)),
+             restriction_tracker.current_watermark()))
+
 
 class DoFnRunner(Receiver):
   """For internal use only; no backwards-compatibility guarantees.
@@ -721,6 +735,9 @@ class DoFnRunner(Receiver):
         restriction_tracker=self.do_fn_invoker.invoke_create_tracker(
             restriction))
 
+  def try_split(self, fraction):
+    return self.do_fn_invoker.try_split(fraction)
+
   def process_user_timer(self, timer_spec, key, window, timestamp):
     try:
       self.do_fn_invoker.invoke_user_timer(timer_spec, key, window, timestamp)
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
index e908a5c..5f8fa3b 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
@@ -495,11 +495,21 @@ class FnApiRunner(runner.PipelineRunner):
       finally:
         controller.state.restore()
 
-    result = BundleManager(
+    result, splits = BundleManager(
         controller, get_buffer, process_bundle_descriptor,
         self._progress_frequency).process_bundle(
             data_input, data_output)
 
+    def input_for(ptransform_id, input_id):
+      input_pcoll = process_bundle_descriptor.transforms[
+          ptransform_id].inputs[input_id]
+      for read_id, proto in process_bundle_descriptor.transforms.items():
+        if (proto.spec.urn == bundle_processor.DATA_INPUT_URN
+            and input_pcoll in proto.outputs.values()):
+          return read_id, 'out'
+      raise RuntimeError(
+          'No IO transform feeds %s' % ptransform_id)
+
     last_result = result
     while True:
       deferred_inputs = collections.defaultdict(list)
@@ -530,21 +540,60 @@ class FnApiRunner(runner.PipelineRunner):
           deferred_inputs[transform_id, 'out'] = [out.get()]
           written_timers[:] = []
 
-      # Queue any delayed bundle applications.
+      # Queue any process-initiated delayed bundle applications.
       for delayed_application in last_result.process_bundle.residual_roots:
-        # Find the io transform that feeds this transform.
-        # TODO(SDF): Memoize?
-        application = delayed_application.application
-        input_pcoll = process_bundle_descriptor.transforms[
-            application.ptransform_id].inputs[application.input_id]
-        for input_id, proto in process_bundle_descriptor.transforms.items():
-          if (proto.spec.urn == bundle_processor.DATA_INPUT_URN
-              and input_pcoll in proto.outputs.values()):
-            deferred_inputs[input_id, 'out'].append(application.element)
-            break
-        else:
-          raise RuntimeError(
-              'No IO transform feeds %s' % application.ptransform_id)
+        deferred_inputs[
+            input_for(
+                delayed_application.application.ptransform_id,
+                delayed_application.application.input_id)
+        ].append(delayed_application.application.element)
+
+      # Queue any runner-initiated delayed bundle applications.
+      prev_stops = collections.defaultdict(lambda: float('inf'))
+      for split in splits:
+        for delayed_application in split.residual_roots:
+          deferred_inputs[
+              input_for(
+                  delayed_application.application.ptransform_id,
+                  delayed_application.application.input_id)
+          ].append(delayed_application.application.element)
+        for channel_split in split.channel_splits:
+          transform = process_bundle_descriptor.transforms[
+              channel_split.ptransform_id]
+          coder_id = beam_fn_api_pb2.RemoteGrpcPort.FromString(
+              transform.spec.payload).coder_id
+          coder_impl = context.coders[safe_coders[coder_id]].get_impl()
+          # TODO(SDF): This requires deterministic ordering of buffer iteration.
+          # TODO(SDF): The return split is in terms of indices.  Ideally,
+          # a runner could map these back to actual positions to effectively
+          # describe the two "halves" of the now-split range.  Even if we have
+          # to buffer each element we send (or at the very least a bit of
+          # metadata, like position, about each of them) this should be doable
+          # if they're already in memory and we are bounding the buffer size
+          # (e.g. to 10mb plus whatever is eagerly read from the SDK).  In the
+          # case of non-split-points, we can either immediately replay the
+          # "non-split-position" elements or record them as we do the other
+          # delayed applications.
+
+          # Decode and recode to split the encoded buffer by element index.
+          buffer = data_input[
+              channel_split.ptransform_id, channel_split.input_id]
+          input_stream = create_InputStream(''.join(buffer))
+          output_stream = create_OutputStream()
+          index = 0
+          prev_stop = prev_stops[channel_split.ptransform_id]
+          while input_stream.size() > 0:
+            if index > prev_stop:
+              break
+            element = coder_impl.decode_from_stream(input_stream, True)
+            if index >= channel_split.first_residual_element:
+              coder_impl.encode_to_stream(element, output_stream, True)
+            index += 1
+          deferred_inputs[
+              channel_split.ptransform_id, channel_split.input_id].append(
+                  output_stream.get())
+          prev_stops[
+              channel_split.ptransform_id] = channel_split.last_primary_element
 
       if deferred_inputs:
         # The worker will be waiting on these inputs as well.
@@ -552,7 +601,7 @@ class FnApiRunner(runner.PipelineRunner):
           if other_input not in deferred_inputs:
             deferred_inputs[other_input] = []
         # TODO(robertwb): merge results
-        last_result = BundleManager(
+        last_result, splits = BundleManager(
             controller,
             get_buffer,
             process_bundle_descriptor,
@@ -1083,7 +1132,7 @@ class BundleManager(object):
     self._registered = skip_registration
     self._progress_frequency = progress_frequency
 
-  def process_bundle(self, inputs, expected_outputs):
+  def process_bundle(self, inputs, expected_outputs, test_splits=False):
     # Unique id for the instruction processing this bundle.
     BundleManager._uid_counter += 1
     process_bundle_id = 'bundle_%s' % BundleManager._uid_counter
@@ -1108,6 +1157,17 @@ class BundleManager(object):
         data_out.write(element_data)
       data_out.close()
 
+    # TODO(robertwb): Control this via a pipeline option.
+    if test_splits:
+      # Inject some splits.
+      random_splitter = BundleSplitter(
+          self._controller,
+          process_bundle_id,
+          self._bundle_descriptor.transforms.keys())
+      random_splitter.start()
+    else:
+      random_splitter = None
+
     # Actually start the bundle.
     if registration_future and registration_future.get().error:
       raise RuntimeError(registration_future.get().error)
@@ -1138,9 +1198,16 @@ class BundleManager(object):
       logging.debug('Wait for the bundle to finish.')
       result = result_future.get()
 
+    if random_splitter:
+      random_splitter.stop()
+      split_results = random_splitter.split_results()
+    else:
+      split_results = []
+
     if result.error:
       raise RuntimeError(result.error)
-    return result
+
+    return result, split_results
 
 
 class ProgressRequester(threading.Thread):
@@ -1181,6 +1248,47 @@ class ProgressRequester(threading.Thread):
     self._done = True
 
 
+class BundleSplitter(threading.Thread):
+  def __init__(self, controller, instruction_id, split_transforms,
+               frequency=.03, split_fractions=(.5, .25, 0)):
+    super(BundleSplitter, self).__init__()
+    self._controller = controller
+    self._instruction_id = instruction_id
+    self._split_transforms = split_transforms
+    self._split_fractions = split_fractions
+    self._frequency = frequency
+    self._results = []
+    self._done = False
+
+  def run(self):
+    for fraction in self._split_fractions:
+      if self._done:
+        return
+      split_result = self._controller.control_handler.push(
+          beam_fn_api_pb2.InstructionRequest(
+              process_bundle_split=beam_fn_api_pb2.ProcessBundleSplitRequest(
+                  instruction_reference=self._instruction_id,
+                  desired_splits={
+                      transform_id:
+                      beam_fn_api_pb2.ProcessBundleSplitRequest.DesiredSplit(
+                          fraction_of_remainder=fraction)
+                      for transform_id in self._split_transforms}))).get()
+      if split_result.error:
+        logging.info('Unable to split at %s: %s' % (
+            fraction, split_result.error))
+      elif split_result.process_bundle_split:
+        self._results.append(split_result.process_bundle_split)
+      time.sleep(self._frequency)
+
+  def split_results(self):
+    self.stop()
+    self.join()
+    return self._results
+
+  def stop(self):
+    self._done = True
+
+
 class ControlFuture(object):
   def __init__(self, instruction_id, response=None):
     self.instruction_id = instruction_id
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py
index 6cffc02..db2d790 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -27,6 +27,7 @@ import json
 import logging
 import random
 import re
+import threading
 from builtins import next
 from builtins import object
 
@@ -113,9 +114,15 @@ class DataInputOperation(RunnerIOOperation):
         windowed_coder, target=input_target, data_channel=data_channel)
     # We must do this manually as we don't have a spec or spec.output_coders.
     self.receivers = [
-        operations.ConsumerSet(
+        operations.ConsumerSet.create(
             self.counter_factory, self.name_context.step_name, 0,
             next(iter(itervalues(consumers))), self.windowed_coder)]
+    self.splitting_lock = threading.Lock()
+
+  def start(self):
+    super(DataInputOperation, self).start()
+    self.index = -1
+    self.stop = float('inf')
 
   def process(self, windowed_value):
     self.output(windowed_value)
@@ -123,10 +130,38 @@ class DataInputOperation(RunnerIOOperation):
   def process_encoded(self, encoded_windowed_values):
     input_stream = coder_impl.create_InputStream(encoded_windowed_values)
     while input_stream.size() > 0:
+      with self.splitting_lock:
+        if self.index == self.stop - 1:
+          return
+        self.index += 1
       decoded_value = self.windowed_coder_impl.decode_from_stream(
           input_stream, True)
       self.output(decoded_value)
 
+  def try_split(self, fraction_of_remainder, total_buffer_size=None):
+    with self.splitting_lock:
+      # If total_buffer_size is not provided, pick something.
+      if not total_buffer_size:
+        total_buffer_size = self.index + 2
+      elif self.stop and total_buffer_size > self.stop:
+        total_buffer_size = self.stop
+      # Compute, as a fraction, how much further to go.
+      # TODO(SDF): Take into account progress of current element.
+      stop_offset = (total_buffer_size - self.index) * fraction_of_remainder
+      # If it's less than a whole element, try splitting the current element.
+      if int(stop_offset) == 0:
+        split = self.receivers[0].try_split(stop_offset)
+        if split:
+          element_primary, element_residual = split
+          self.stop = self.index + 1
+          return self.stop - 2, element_primary, element_residual, self.stop
+
+      # Otherwise, split at the closest element boundary.
+      desired_stop = max(int(stop_offset), 1) + self.index
+      if desired_stop < self.stop:
+        self.stop = desired_stop
+        return self.stop - 1, None, None, self.stop
+
 
 class _StateBackedIterable(object):
   def __init__(self, state_handler, state_key, coder_or_impl):
@@ -413,6 +448,7 @@ class BundleProcessor(object):
     self.ops = self.create_execution_tree(self.process_bundle_descriptor)
     for op in self.ops.values():
       op.setup()
+    self.splitting_lock = threading.Lock()
 
   def create_execution_tree(self, descriptor):
 
@@ -509,8 +545,40 @@ class BundleProcessor(object):
           for op, residual in execution_context.delayed_applications]
 
     finally:
+      # Ensure any in-flight split attempts complete.
+      with self.splitting_lock:
+        pass
       self.state_sampler.stop_if_still_running()
 
+  def try_split(self, bundle_split_request):
+    split_response = beam_fn_api_pb2.ProcessBundleSplitResponse()
+    with self.splitting_lock:
+      for op in self.ops.values():
+        if isinstance(op, DataInputOperation):
+          desired_split = bundle_split_request.desired_splits.get(
+              op.target.primitive_transform_reference)
+          if desired_split:
+            split = op.try_split(desired_split.fraction_of_remainder,
+                                 desired_split.estimated_input_elements)
+            if split:
+              (primary_end, element_primary, element_residual, residual_start,
+              ) = split
+              if element_primary:
+                split_response.primary_roots.add().CopyFrom(
+                    self.delayed_bundle_application(
+                        *element_primary).application)
+              if element_residual:
+                split_response.residual_roots.add().CopyFrom(
+                    self.delayed_bundle_application(*element_residual))
+              split_response.channel_splits.extend([
+                  beam_fn_api_pb2.ProcessBundleSplitResponse.ChannelSplit(
+                      ptransform_id=op.target.primitive_transform_reference,
+                      input_id=op.target.name,
+                      last_primary_element=primary_end,
+                      first_residual_element=residual_start)])
+
+    return split_response
+
   def delayed_bundle_application(self, op, deferred_remainder):
     ptransform_id, main_input_tag, main_input_coder, outputs = op.input_info
     # TODO(SDF): For non-root nodes, need main_input_coder + residual_coder.
diff --git a/sdks/python/apache_beam/runners/worker/operations.pxd b/sdks/python/apache_beam/runners/worker/operations.pxd
index 10c3c41..9f0c015 100644
--- a/sdks/python/apache_beam/runners/worker/operations.pxd
+++ b/sdks/python/apache_beam/runners/worker/operations.pxd
@@ -38,6 +38,10 @@ cdef class ConsumerSet(Receiver):
   cpdef update_counters_finish(self)
 
 
+cdef class SingletonConsumerSet(ConsumerSet):
+  cdef Operation consumer
+
+
 cdef class Operation(object):
   cdef readonly name_context
   cdef readonly operation_name
diff --git a/sdks/python/apache_beam/runners/worker/operations.py b/sdks/python/apache_beam/runners/worker/operations.py
index a6e0c31..c7c767f 100644
--- a/sdks/python/apache_beam/runners/worker/operations.py
+++ b/sdks/python/apache_beam/runners/worker/operations.py
@@ -73,6 +73,14 @@ class ConsumerSet(Receiver):
   the other edge.
   ConsumerSet are attached to the outputting Operation.
   """
+  @staticmethod
+  def create(counter_factory, step_name, output_index, consumers, coder):
+    if len(consumers) == 1:
+      return SingletonConsumerSet(
+          counter_factory, step_name, output_index, consumers, coder)
+    else:
+      return ConsumerSet(
+          counter_factory, step_name, output_index, consumers, coder)
 
   def __init__(
       self, counter_factory, step_name, output_index, consumers, coder):
@@ -90,6 +98,14 @@ class ConsumerSet(Receiver):
       cython.cast(Operation, consumer).process(windowed_value)
     self.update_counters_finish()
 
+  def try_split(self, fraction_of_remainder):
+    # TODO(SDF): Consider supporting splitting each consumer individually.
+    # This would never come up in the existing SDF expansion, but might
+    # be useful to support fused SDF nodes.
+    # This would require dedicated delivery of the split results to each
+    # of the consumers separately.
+    return None
+
   def update_counters_start(self, windowed_value):
     self.opcounter.update_from(windowed_value)
 
@@ -102,6 +118,23 @@ class ConsumerSet(Receiver):
         len(self.consumers))
 
 
+class SingletonConsumerSet(ConsumerSet):
+  def __init__(
+      self, counter_factory, step_name, output_index, consumers, coder):
+    assert len(consumers) == 1
+    super(SingletonConsumerSet, self).__init__(
+        counter_factory, step_name, output_index, consumers, coder)
+    self.consumer = consumers[0]
+
+  def receive(self, windowed_value):
+    self.update_counters_start(windowed_value)
+    self.consumer.process(windowed_value)
+    self.update_counters_finish()
+
+  def try_split(self, fraction_of_remainder):
+    return self.consumer.try_split(fraction_of_remainder)
+
+
 class Operation(object):
   """An operation representing the live version of a work item specification.
 
@@ -157,11 +190,13 @@ class Operation(object):
       # top-level operation, should have output_coders
       #TODO(pabloem): Define better what step name is used here.
       if getattr(self.spec, 'output_coders', None):
-        self.receivers = [ConsumerSet(self.counter_factory,
-                                      self.name_context.logging_name(),
-                                      i,
-                                      self.consumers[i], coder)
-                          for i, coder in enumerate(self.spec.output_coders)]
+        self.receivers = [
+            ConsumerSet.create(
+                self.counter_factory,
+                self.name_context.logging_name(),
+                i,
+                self.consumers[i], coder)
+            for i, coder in enumerate(self.spec.output_coders)]
     self.setup_done = True
 
   def start(self):
@@ -174,6 +209,9 @@ class Operation(object):
     """Process element in operation."""
     pass
 
+  def try_split(self, fraction_of_remainder):
+    return None
+
   def finish(self):
     """Finish operation."""
     pass
@@ -327,7 +365,7 @@ class ImpulseReadOperation(Operation):
         name_context, None, counter_factory, state_sampler)
     self.source = source
     self.receivers = [
-        ConsumerSet(
+        ConsumerSet.create(
             self.counter_factory, self.name_context.step_name, 0,
             next(iter(consumers.values())), output_coder)]
 
@@ -553,6 +591,12 @@ class SdfProcessElements(DoOperation):
         self.execution_context.delayed_applications.append(
             (self, delayed_application))
 
+  def try_split(self, fraction_of_remainder):
+    split = self.dofn_runner.try_split(fraction_of_remainder)
+    if split:
+      primary, residual = split
+      return (self, primary), (self, residual)
+
 
 class DoFnRunnerReceiver(Receiver):
 
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py
index 6067181..1528d23 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -197,7 +197,13 @@ class SdkHarness(object):
     logging.debug(
         "Currently using %s threads." % len(self._process_thread_pool._threads))
 
+  def _request_process_bundle_split(self, request):
+    self._request_process_bundle_action(request)
+
   def _request_process_bundle_progress(self, request):
+    self._request_process_bundle_action(request)
+
+  def _request_process_bundle_action(self, request):
 
     def task():
       instruction_reference = getattr(
@@ -304,6 +310,14 @@ class SdkWorker(object):
     processor.reset()
     self.cached_bundle_processors[bundle_descriptor_id].append(processor)
 
+  def process_bundle_split(self, request, instruction_id):
+    processor = self.active_bundle_processors.get(request.instruction_reference)
+    if not processor:
+      raise ValueError('Instruction not running: %s' % instruction_id)
+    return beam_fn_api_pb2.InstructionResponse(
+        instruction_id=instruction_id,
+        process_bundle_split=processor.try_split(request))
+
   def process_bundle_progress(self, request, instruction_id):
     # It is an error to get progress for a not-in-flight bundle.
     processor = self.active_bundle_processors.get(request.instruction_reference)


Mime
View raw message