beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From rober...@apache.org
Subject [beam] branch master updated: Make SDFBoundedSource wrapper work with dynamic splitting (#8944)
Date Tue, 23 Jul 2019 11:48:14 GMT
This is an automated email from the ASF dual-hosted git repository.

robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 75875de  Make SDFBoundedSource wrapper work with dynamic splitting (#8944)
75875de is described below

commit 75875def7098dec8fcab89941ee398a34fbf5fb1
Author: Boyuan Zhang <36090911+boyuanzz@users.noreply.github.com>
AuthorDate: Tue Jul 23 04:47:45 2019 -0700

    Make SDFBoundedSource wrapper work with dynamic splitting (#8944)
---
 sdks/python/apache_beam/io/iobase.py         | 57 ++++++++++++++++++++--------
 sdks/python/apache_beam/io/iobase_test.py    | 30 +++++++++------
 sdks/python/apache_beam/io/range_trackers.py | 19 ++++------
 3 files changed, 68 insertions(+), 38 deletions(-)

diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py
index bb7c03c..6763c57 100644
--- a/sdks/python/apache_beam/io/iobase.py
+++ b/sdks/python/apache_beam/io/iobase.py
@@ -1346,15 +1346,27 @@ class _SDFBoundedSourceWrapper(ptransform.PTransform):
 
     Delegated RangeTracker guarantees synchronization safety.
     """
-    def __init__(self, range_tracker):
-      if not isinstance(range_tracker, RangeTracker):
+    def __init__(self, restriction):
+      if not isinstance(restriction, SourceBundle):
         raise ValueError('Initializing SDFBoundedSourceRestrictionTracker'
-                         'requires a RangeTracker')
-      self._delegate_range_tracker = range_tracker
+                         'requires a SourceBundle')
+      self._delegate_range_tracker = restriction.source.get_range_tracker(
+          restriction.start_position, restriction.stop_position)
+      self._source = restriction.source
+      self._weight = restriction.weight
+
+    def current_progress(self):
+      return RestrictionProgress(
+          fraction=self._delegate_range_tracker.fraction_consumed())
 
     def current_restriction(self):
-      return (self._delegate_range_tracker.start_position(),
-              self._delegate_range_tracker.stop_position())
+      start_pos = self._delegate_range_tracker.start_position()
+      stop_pos = self._delegate_range_tracker.stop_position()
+      return SourceBundle(
+          self._weight,
+          self._source,
+          start_pos,
+          stop_pos)
 
     def start_pos(self):
       return self._delegate_range_tracker.start_position()
@@ -1373,15 +1385,32 @@ class _SDFBoundedSourceWrapper(ptransform.PTransform):
       # Need to stash current stop_pos before splitting since
       # range_tracker.split will update its stop_pos if splits
       # successfully.
+      start_pos = self.start_pos()
       stop_pos = self.stop_pos()
-      split_pos, _ = self._delegate_range_tracker.try_split(position)
-      if split_pos:
-        return ((self._delegate_range_tracker.start_position(), split_pos),
-                (split_pos, stop_pos))
+      split_result = self._delegate_range_tracker.try_split(position)
+      if split_result:
+        split_pos, split_fraction = split_result
+        primary_weight = self._weight * split_fraction
+        residual_weight = self._weight - primary_weight
+        # Update self._weight to primary weight
+        self._weight = primary_weight
+        return (SourceBundle(primary_weight, self._source, start_pos,
+                             split_pos),
+                SourceBundle(residual_weight, self._source, split_pos,
+                             stop_pos))
 
     def deferred_status(self):
       return None
 
+    def current_watermark(self):
+      return None
+
+    def get_delegate_range_tracker(self):
+      return self._delegate_range_tracker
+
+    def get_tracking_source(self):
+      return self._source
+
   class _SDFBoundedSourceRestrictionProvider(core.RestrictionProvider):
     """A `RestrictionProvider` that is used by SDF for `BoundedSource`."""
 
@@ -1399,8 +1428,7 @@ class _SDFBoundedSourceWrapper(ptransform.PTransform):
 
     def create_tracker(self, restriction):
       return _SDFBoundedSourceWrapper._SDFBoundedSourceRestrictionTracker(
-          restriction.source.get_range_tracker(restriction.start_position,
-                                               restriction.stop_position))
+          restriction)
 
     def split(self, element, restriction):
       # Invoke source.split to get initial splitting results.
@@ -1431,9 +1459,8 @@ class _SDFBoundedSourceWrapper(ptransform.PTransform):
           restriction_tracker=core.DoFn.RestrictionParam(
               _SDFBoundedSourceWrapper._SDFBoundedSourceRestrictionProvider(
                   source, chunk_size))):
-        start_pos, end_pos = restriction_tracker.current_restriction()
-        range_tracker = self.source.get_range_tracker(start_pos, end_pos)
-        return self.source.read(range_tracker)
+        return restriction_tracker.get_tracking_source().read(
+            restriction_tracker.get_delegate_range_tracker())
 
     return SDFBoundedSourceDoFn(self.source)
 
diff --git a/sdks/python/apache_beam/io/iobase_test.py b/sdks/python/apache_beam/io/iobase_test.py
index 65fc89c..c7d1656 100644
--- a/sdks/python/apache_beam/io/iobase_test.py
+++ b/sdks/python/apache_beam/io/iobase_test.py
@@ -25,7 +25,6 @@ from apache_beam.io.concat_source import ConcatSource
 from apache_beam.io.concat_source_test import RangeSource
 from apache_beam.io import iobase
 from apache_beam.io.iobase import SourceBundle
-from apache_beam.io.range_trackers import OffsetRangeTracker
 
 
 class SDFBoundedSourceRestrictionProviderTest(unittest.TestCase):
@@ -115,14 +114,17 @@ class SDFBoundedSourceRestrictionTrackerTest(unittest.TestCase):
   def setUp(self):
     self.initial_start_pos = 0
     self.initial_stop_pos = 4
-    self.range_tracker = OffsetRangeTracker(self.initial_start_pos,
-                                            self.initial_stop_pos)
+    source_bundle = SourceBundle(
+        self.initial_stop_pos - self.initial_start_pos,
+        RangeSource(self.initial_start_pos, self.initial_stop_pos),
+        self.initial_start_pos,
+        self.initial_stop_pos)
     self.sdf_restriction_tracker = (
         iobase._SDFBoundedSourceWrapper._SDFBoundedSourceRestrictionTracker(
-            self.range_tracker))
+            source_bundle))
 
   def test_current_restriction_before_split(self):
-    actual_start, actual_stop = (
+    _, _, actual_start, actual_stop = (
         self.sdf_restriction_tracker.current_restriction())
     self.assertEqual(self.initial_start_pos, actual_start)
     self.assertEqual(self.initial_stop_pos, actual_stop)
@@ -136,14 +138,20 @@ class SDFBoundedSourceRestrictionTrackerTest(unittest.TestCase):
                      self.sdf_restriction_tracker.current_restriction())
 
   def test_try_split_at_remainder(self):
-    fraction_of_remainder = 0.5
-    expected_primary = (0, 3)
-    expected_residual = (3, 4)
-    self.sdf_restriction_tracker.try_claim(1)
+    fraction_of_remainder = 0.4
+    expected_primary = (0, 2, 2.0)
+    expected_residual = (2, 4, 2.0)
+    self.sdf_restriction_tracker.try_claim(0)
     actual_primary, actual_residual = (
         self.sdf_restriction_tracker.try_split(fraction_of_remainder))
-    self.assertEqual(expected_primary, actual_primary)
-    self.assertEqual(expected_residual, actual_residual)
+    self.assertEqual(expected_primary, (actual_primary.start_position,
+                                        actual_primary.stop_position,
+                                        actual_primary.weight))
+    self.assertEqual(expected_residual, (actual_residual.start_position,
+                                         actual_residual.stop_position,
+                                         actual_residual.weight))
+    self.assertEqual(actual_primary.weight,
+                     self.sdf_restriction_tracker._weight)
 
 
 if __name__ == '__main__':
diff --git a/sdks/python/apache_beam/io/range_trackers.py b/sdks/python/apache_beam/io/range_trackers.py
index 5bf4898..c46f801 100644
--- a/sdks/python/apache_beam/io/range_trackers.py
+++ b/sdks/python/apache_beam/io/range_trackers.py
@@ -167,17 +167,19 @@ class OffsetRangeTracker(iobase.RangeTracker):
 
   def fraction_consumed(self):
     with self._lock:
-      fraction = ((1.0 * (self._last_record_start - self.start_position()) /
-                   (self.stop_position() - self.start_position())) if
-                  self.stop_position() != self.start_position() else 0.0)
-
       # self.last_record_start may become larger than self.end_offset when
       # reading the records since any record that starts before the first 'split
       # point' at or after the defined 'stop offset' is considered to be within
       # the range of the OffsetRangeTracker. Hence fraction could be > 1.
       # self.last_record_start is initialized to -1, hence fraction may be < 0.
       # Bounding the fraction to the range [0, 1].
-      return max(0.0, min(1.0, fraction))
+      return self.position_to_fraction(self._last_record_start,
+                                       self.start_position(),
+                                       self.stop_position())
+
+  def position_to_fraction(self, pos, start, stop):
+    fraction = 1.0 * (pos - start) / (stop - start) if start != stop else 0.0
+    return max(0.0, min(1.0, fraction))
 
   def position_at_fraction(self, fraction):
     if self.stop_position() == OffsetRangeTracker.OFFSET_INFINITY:
@@ -271,13 +273,6 @@ class OrderedPositionRangeTracker(iobase.RangeTracker):
       return self.position_to_fraction(
           self._last_claim, self._start_position, self._stop_position)
 
-  def position_to_fraction(self, pos, start, end):
-    """
-    Converts a position `pos` betweeen `start` and `end` (inclusive) to a
-    fraction between 0 and 1.
-    """
-    raise NotImplementedError
-
   def fraction_to_position(self, fraction, start, end):
     """
     Converts a fraction between 0 and 1 to a position between start and end.


Mime
View raw message