beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From "ASF GitHub Bot (JIRA)" <j...@apache.org>
Subject [jira] [Commented] (BEAM-1872) implement Reshuffle transform in python, make it experimental in Java
Date Wed, 06 Dec 2017 21:32:00 GMT

    [ https://issues.apache.org/jira/browse/BEAM-1872?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16280959#comment-16280959 ]

ASF GitHub Bot commented on BEAM-1872:
--------------------------------------

robertwb closed pull request #4040: [BEAM-1872] Add IdentityWindowFn for use in Reshuffle
URL: https://github.com/apache/beam/pull/4040
 
 
   

This PR was merged from a forked repository. Because GitHub hides the
original diff of a merged pull request from a fork, the diff is
reproduced below for the sake of provenance:

diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py
index 689eab7b842..ccf2516eee9 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -565,9 +565,10 @@ class WindowIntoDoFn(beam.DoFn):
     def __init__(self, windowing):
       self.windowing = windowing
 
-    def process(self, element, timestamp=beam.DoFn.TimestampParam):
+    def process(self, element, timestamp=beam.DoFn.TimestampParam,
+                window=beam.DoFn.WindowParam):
       new_windows = self.windowing.windowfn.assign(
-          WindowFn.AssignContext(timestamp, element=element))
+          WindowFn.AssignContext(timestamp, element=element, window=window))
       yield WindowedValue(element, timestamp, new_windows)
   from apache_beam.transforms.core import Windowing
   from apache_beam.transforms.window import WindowFn, WindowedValue
diff --git a/sdks/python/apache_beam/testing/util.py b/sdks/python/apache_beam/testing/util.py
index 34c15f9c191..2f18bdee0b1 100644
--- a/sdks/python/apache_beam/testing/util.py
+++ b/sdks/python/apache_beam/testing/util.py
@@ -19,13 +19,16 @@
 
 from __future__ import absolute_import
 
+import collections
 import glob
 import tempfile
 
 from apache_beam import pvalue
 from apache_beam.transforms import window
 from apache_beam.transforms.core import Create
+from apache_beam.transforms.core import DoFn
 from apache_beam.transforms.core import Map
+from apache_beam.transforms.core import ParDo
 from apache_beam.transforms.core import WindowInto
 from apache_beam.transforms.ptransform import PTransform
 from apache_beam.transforms.util import CoGroupByKey
@@ -37,6 +40,7 @@
     'is_empty',
     # open_shards is internal and has no backwards compatibility guarantees.
     'open_shards',
+    'TestWindowedValue',
     ]
 
 
@@ -46,11 +50,32 @@ class BeamAssertException(Exception):
   pass
 
 
+# Used for reifying timestamps and windows for assert_that matchers.
+TestWindowedValue = collections.namedtuple(
+    'TestWindowedValue', 'value timestamp windows')
+
+
+def contains_in_any_order(iterable):
+  """Creates an object that matches another iterable if both contain the
+  same items with the same multiplicities, regardless of order.
+
+  Arguments:
+    iterable: An iterable of hashable objects.
+  """
+  class InAnyOrder(object):
+    def __init__(self, iterable):
+      self._counter = collections.Counter(iterable)
+
+    def __eq__(self, other):
+      return self._counter == collections.Counter(other)
+
+  return InAnyOrder(iterable)
+
+
 # Note that equal_to always sorts the expected and actual since what we
 # compare are PCollections for which there is no guaranteed order.
 # However the sorting does not go beyond top level therefore [1,2] and [2,1]
 # are considered equal and [[1,2]] and [[2,1]] are not.
-# TODO(silviuc): Add contains_in_any_order-style matchers.
 def equal_to(expected):
   expected = list(expected)
 
@@ -72,7 +97,7 @@ def _empty(actual):
   return _empty
 
 
-def assert_that(actual, matcher, label='assert_that'):
+def assert_that(actual, matcher, label='assert_that', reify_windows=False):
   """A PTransform that checks a PCollection has an expected value.
 
   Note that assert_that should be used only for testing pipelines since the
@@ -85,15 +110,27 @@ def assert_that(actual, matcher, label='assert_that'):
       expectations and raises BeamAssertException if they are not met.
     label: Optional string label. This is needed in case several assert_that
       transforms are introduced in the same pipeline.
+    reify_windows: If True, matcher is passed a list of TestWindowedValue.
 
   Returns:
     Ignored.
   """
   assert isinstance(actual, pvalue.PCollection)
 
+  class ReifyTimestampWindow(DoFn):
+    def process(self, element, timestamp=DoFn.TimestampParam,
+                window=DoFn.WindowParam):
+      # This returns TestWindowedValue instead of
+      # beam.utils.windowed_value.WindowedValue because ParDo will extract
+      # the timestamp and window out of the latter.
+      return [TestWindowedValue(element, timestamp, [window])]
+
   class AssertThat(PTransform):
 
     def expand(self, pcoll):
+      if reify_windows:
+        pcoll = pcoll | ParDo(ReifyTimestampWindow())
+
       # We must have at least a single element to ensure the matcher
       # code gets run even if the input pcollection is empty.
       keyed_singleton = pcoll.pipeline | Create([(None, None)])
diff --git a/sdks/python/apache_beam/testing/util_test.py b/sdks/python/apache_beam/testing/util_test.py
index 9d3869381b6..e4e86941669 100644
--- a/sdks/python/apache_beam/testing/util_test.py
+++ b/sdks/python/apache_beam/testing/util_test.py
@@ -21,9 +21,13 @@
 
 from apache_beam import Create
 from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.testing.util import TestWindowedValue
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
 from apache_beam.testing.util import is_empty
+from apache_beam.transforms.window import GlobalWindow
+from apache_beam.transforms.window import IntervalWindow
+from apache_beam.utils.timestamp import MIN_TIMESTAMP
 
 
 class UtilTest(unittest.TestCase):
@@ -32,11 +36,49 @@ def test_assert_that_passes(self):
     with TestPipeline() as p:
       assert_that(p | Create([1, 2, 3]), equal_to([1, 2, 3]))
 
+  def test_assert_that_passes_empty_equal_to(self):
+    with TestPipeline() as p:
+      assert_that(p | Create([]), equal_to([]))
+
+  def test_assert_that_passes_empty_is_empty(self):
+    with TestPipeline() as p:
+      assert_that(p | Create([]), is_empty())
+
+  def test_windowed_value_passes(self):
+    expected = [TestWindowedValue(v, MIN_TIMESTAMP, [GlobalWindow()])
+                for v in [1, 2, 3]]
+    with TestPipeline() as p:
+      assert_that(p | Create([2, 3, 1]), equal_to(expected), reify_windows=True)
+
   def test_assert_that_fails(self):
     with self.assertRaises(Exception):
       with TestPipeline() as p:
         assert_that(p | Create([1, 10, 100]), equal_to([1, 2, 3]))
 
+  def test_windowed_value_assert_fail_unmatched_value(self):
+    expected = [TestWindowedValue(v + 1, MIN_TIMESTAMP, [GlobalWindow()])
+                for v in [1, 2, 3]]
+    with self.assertRaises(Exception):
+      with TestPipeline() as p:
+        assert_that(p | Create([2, 3, 1]), equal_to(expected),
+                    reify_windows=True)
+
+  def test_windowed_value_assert_fail_unmatched_timestamp(self):
+    expected = [TestWindowedValue(v, 1, [GlobalWindow()])
+                for v in [1, 2, 3]]
+    with self.assertRaises(Exception):
+      with TestPipeline() as p:
+        assert_that(p | Create([2, 3, 1]), equal_to(expected),
+                    reify_windows=True)
+
+  def test_windowed_value_assert_fail_unmatched_window(self):
+    expected = [TestWindowedValue(v, MIN_TIMESTAMP, [IntervalWindow(0, 1)])
+                for v in [1, 2, 3]]
+    with self.assertRaises(Exception):
+      with TestPipeline() as p:
+        assert_that(p | Create([2, 3, 1]), equal_to(expected),
+                    reify_windows=True)
+
   def test_assert_that_fails_on_empty_input(self):
     with self.assertRaises(Exception):
       with TestPipeline() as p:
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index e650b399a07..533634dba58 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -1579,8 +1579,10 @@ class WindowIntoFn(DoFn):
     def __init__(self, windowing):
       self.windowing = windowing
 
-    def process(self, element, timestamp=DoFn.TimestampParam):
-      context = WindowFn.AssignContext(timestamp, element=element)
+    def process(self, element, timestamp=DoFn.TimestampParam,
+                window=DoFn.WindowParam):
+      context = WindowFn.AssignContext(timestamp, element=element,
+                                       window=window)
       new_windows = self.windowing.windowfn.assign(context)
       yield WindowedValue(element, context.timestamp, new_windows)
 
diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py
index 85d4975e3f5..332387ad414 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -22,6 +22,7 @@
 
 import collections
 import contextlib
+import random
 import time
 
 from apache_beam import typehints
@@ -29,12 +30,20 @@
 from apache_beam.transforms import window
 from apache_beam.transforms.core import CombinePerKey
 from apache_beam.transforms.core import DoFn
+from apache_beam.transforms.core import FlatMap
 from apache_beam.transforms.core import Flatten
 from apache_beam.transforms.core import GroupByKey
 from apache_beam.transforms.core import Map
 from apache_beam.transforms.core import ParDo
+from apache_beam.transforms.core import WindowInto
 from apache_beam.transforms.ptransform import PTransform
 from apache_beam.transforms.ptransform import ptransform_fn
+from apache_beam.transforms.trigger import AccumulationMode
+from apache_beam.transforms.trigger import AfterCount
+from apache_beam.transforms.window import NonMergingWindowFn
+from apache_beam.transforms.window import TimestampCombiner
+from apache_beam.transforms.window import TimestampedValue
+from apache_beam.utils import urns
 from apache_beam.utils import windowed_value
 
 __all__ = [
@@ -43,10 +52,12 @@
     'Keys',
     'KvSwap',
     'RemoveDuplicates',
+    'Reshuffle',
     'Values',
     ]
 
-
+K = typehints.TypeVariable('K')
+V = typehints.TypeVariable('V')
 T = typehints.TypeVariable('T')
 
 
@@ -423,3 +434,102 @@ def expand(self, pcoll):
           self._batch_size_estimator))
     else:
       return pcoll | ParDo(_WindowAwareBatchingDoFn(self._batch_size_estimator))
+
+
+class _IdentityWindowFn(NonMergingWindowFn):
+  """Windowing function that preserves existing windows.
+
+  To be used internally with the Reshuffle transform.
+  Will raise an exception when used after DoFns that return TimestampedValue
+  elements.
+  """
+
+  def __init__(self, window_coder):
+    """Create a new WindowFn with compatible coder.
+    To be applied to PCollections with windows that are compatible with the
+    given coder.
+
+    Arguments:
+      window_coder: coders.Coder object to be used on windows.
+    """
+    super(_IdentityWindowFn, self).__init__()
+    if window_coder is None:
+      raise ValueError('window_coder should not be None')
+    self._window_coder = window_coder
+
+  def assign(self, assign_context):
+    if assign_context.window is None:
+      raise ValueError(
+          'assign_context.window should not be None. '
+          'This might be due to a DoFn returning a TimestampedValue.')
+    return [assign_context.window]
+
+  def get_window_coder(self):
+    return self._window_coder
+
+  def to_runner_api_parameter(self, unused_context):
+    pass  # Overridden by register_pickle_urn below.
+
+  urns.RunnerApiFn.register_pickle_urn(urns.RESHUFFLE_TRANSFORM)
+
+
+@typehints.with_input_types(typehints.KV[K, V])
+@typehints.with_output_types(typehints.KV[K, V])
+class ReshufflePerKey(PTransform):
+  """PTransform that returns a PCollection equivalent to its input,
+  but operationally provides some of the side effects of a GroupByKey,
+  in particular preventing fusion of the surrounding transforms,
+  checkpointing, and deduplication by id.
+
+  ReshufflePerKey is experimental. No backwards compatibility guarantees.
+  """
+
+  def expand(self, pcoll):
+    class ReifyTimestamps(DoFn):
+      def process(self, element, timestamp=DoFn.TimestampParam):
+        yield element[0], TimestampedValue(element[1], timestamp)
+
+    class RestoreTimestamps(DoFn):
+      def process(self, element, window=DoFn.WindowParam):
+        # Pass the current window since _IdentityWindowFn wouldn't know how
+        # to generate it.
+        yield windowed_value.WindowedValue(
+            (element[0], element[1].value), element[1].timestamp, [window])
+
+    windowing_saved = pcoll.windowing
+    result = (pcoll
+              | ParDo(ReifyTimestamps())
+              | 'IdentityWindow' >> WindowInto(
+                  _IdentityWindowFn(
+                      windowing_saved.windowfn.get_window_coder()),
+                  trigger=AfterCount(1),
+                  accumulation_mode=AccumulationMode.DISCARDING,
+                  timestamp_combiner=TimestampCombiner.OUTPUT_AT_EARLIEST,
+                  )
+              | GroupByKey()
+              | 'ExpandIterable' >> FlatMap(
+                  lambda e: [(e[0], value) for value in e[1]])
+              | ParDo(RestoreTimestamps()))
+    result._windowing = windowing_saved
+    return result
+
+
+@typehints.with_input_types(T)
+@typehints.with_output_types(T)
+class Reshuffle(PTransform):
+  """PTransform that returns a PCollection equivalent to its input,
+  but operationally provides some of the side effects of a GroupByKey,
+  in particular preventing fusion of the surrounding transforms,
+  checkpointing, and deduplication by id.
+
+  Reshuffle adds a temporary random key to each element, performs a
+  ReshufflePerKey, and finally removes the temporary key.
+
+  Reshuffle is experimental. No backwards compatibility guarantees.
+  """
+
+  def expand(self, pcoll):
+    return (pcoll
+            | 'AddRandomKeys' >> Map(lambda t: (random.getrandbits(32), t))
+            | ReshufflePerKey()
+            | 'RemoveRandomKeys' >> Map(lambda t: t[1]))
diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py
index 6064e2ccce1..0be418028be 100644
--- a/sdks/python/apache_beam/transforms/util_test.py
+++ b/sdks/python/apache_beam/transforms/util_test.py
@@ -21,11 +21,24 @@
 import unittest
 
 import apache_beam as beam
+from apache_beam.coders import coders
+from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.options.pipeline_options import StandardOptions
 from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.testing.util import TestWindowedValue
 from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import contains_in_any_order
 from apache_beam.testing.util import equal_to
 from apache_beam.transforms import util
 from apache_beam.transforms import window
+from apache_beam.transforms.window import GlobalWindow
+from apache_beam.transforms.window import GlobalWindows
+from apache_beam.transforms.window import IntervalWindow
+from apache_beam.transforms.window import Sessions
+from apache_beam.transforms.window import SlidingWindows
+from apache_beam.transforms.window import TimestampedValue
+from apache_beam.utils import timestamp
+from apache_beam.utils.windowed_value import WindowedValue
 
 
 class FakeClock(object):
@@ -106,3 +119,220 @@ def test_target_overhead(self):
       with batch_estimator.record_time(actual_sizes[-1]):
         clock.sleep(batch_duration(actual_sizes[-1]))
     self.assertEqual(expected_sizes, actual_sizes)
+
+
+class IdentityWindowTest(unittest.TestCase):
+
+  def test_window_preserved(self):
+    expected_timestamp = timestamp.Timestamp(5)
+    expected_window = window.IntervalWindow(1.0, 2.0)
+
+    class AddWindowDoFn(beam.DoFn):
+      def process(self, element):
+        yield WindowedValue(
+            element, expected_timestamp, [expected_window])
+
+    pipeline = TestPipeline()
+    data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
+    expected_windows = [
+        TestWindowedValue(kv, expected_timestamp, [expected_window])
+        for kv in data]
+    before_identity = (pipeline
+                       | 'start' >> beam.Create(data)
+                       | 'add_windows' >> beam.ParDo(AddWindowDoFn()))
+    assert_that(before_identity, equal_to(expected_windows),
+                label='before_identity', reify_windows=True)
+    after_identity = (before_identity
+                      | 'window' >> beam.WindowInto(
+                          beam.transforms.util._IdentityWindowFn(
+                              coders.IntervalWindowCoder())))
+    assert_that(after_identity, equal_to(expected_windows),
+                label='after_identity', reify_windows=True)
+    pipeline.run()
+
+  def test_no_window_context_fails(self):
+    expected_timestamp = timestamp.Timestamp(5)
+    # Assuming the default window function is window.GlobalWindows.
+    expected_window = window.GlobalWindow()
+
+    class AddTimestampDoFn(beam.DoFn):
+      def process(self, element):
+        yield window.TimestampedValue(element, expected_timestamp)
+
+    pipeline = TestPipeline()
+    data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
+    expected_windows = [
+        TestWindowedValue(kv, expected_timestamp, [expected_window])
+        for kv in data]
+    before_identity = (pipeline
+                       | 'start' >> beam.Create(data)
+                       | 'add_timestamps' >> beam.ParDo(AddTimestampDoFn()))
+    assert_that(before_identity, equal_to(expected_windows),
+                label='before_identity', reify_windows=True)
+    after_identity = (before_identity
+                      | 'window' >> beam.WindowInto(
+                          beam.transforms.util._IdentityWindowFn(
+                              coders.GlobalWindowCoder()))
+                      # This DoFn will return TimestampedValues, making
+                      # WindowFn.AssignContext passed to IdentityWindowFn
+                      # contain a window of None. IdentityWindowFn should
+                      # raise an exception.
+                      | 'add_timestamps2' >> beam.ParDo(AddTimestampDoFn()))
+    assert_that(after_identity, equal_to(expected_windows),
+                label='after_identity', reify_windows=True)
+    with self.assertRaisesRegexp(ValueError, r'window.*None.*add_timestamps2'):
+      pipeline.run()
+
+
+class ReshuffleTest(unittest.TestCase):
+
+  def test_reshuffle_contents_unchanged(self):
+    pipeline = TestPipeline()
+    data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 3)]
+    result = (pipeline
+              | 'start' >> beam.Create(data)
+              | 'reshuffle' >> beam.Reshuffle())
+    assert_that(result, equal_to(data))
+    pipeline.run()
+
+  def test_reshuffle_after_gbk_contents_unchanged(self):
+    pipeline = TestPipeline()
+    data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 3)]
+    expected_result = [(1, [1, 2, 3]), (2, [1, 2]), (3, [1])]
+
+    after_gbk = (pipeline
+                 | 'start' >> beam.Create(data)
+                 | 'group_by_key' >> beam.GroupByKey())
+    assert_that(after_gbk, equal_to(expected_result), label='after_gbk')
+    after_reshuffle = (after_gbk
+                       | 'reshuffle' >> beam.Reshuffle())
+    assert_that(after_reshuffle, equal_to(expected_result),
+                label='after_reshuffle')
+    pipeline.run()
+
+  def test_reshuffle_timestamps_unchanged(self):
+    pipeline = TestPipeline()
+    timestamp = 5
+    data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 3)]
+    expected_result = [TestWindowedValue(v, timestamp, [GlobalWindow()])
+                       for v in data]
+    before_reshuffle = (pipeline
+                        | 'start' >> beam.Create(data)
+                        | 'add_timestamp' >> beam.Map(
+                            lambda v: beam.window.TimestampedValue(v,
+                                                                   timestamp)))
+    assert_that(before_reshuffle, equal_to(expected_result),
+                label='before_reshuffle', reify_windows=True)
+    after_reshuffle = (before_reshuffle
+                       | 'reshuffle' >> beam.Reshuffle())
+    assert_that(after_reshuffle, equal_to(expected_result),
+                label='after_reshuffle', reify_windows=True)
+    pipeline.run()
+
+  def test_reshuffle_windows_unchanged(self):
+    pipeline = TestPipeline()
+    data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
+    expected_data = [TestWindowedValue(v, t, [w]) for (v, t, w) in
+                     [((1, [2, 1]), 4.0, IntervalWindow(1.0, 4.0)),
+                      ((2, [2, 1]), 4.0, IntervalWindow(1.0, 4.0)),
+                      ((3, [1]), 3.0, IntervalWindow(1.0, 3.0)),
+                      ((1, [4]), 6.0, IntervalWindow(4.0, 6.0))]]
+    before_reshuffle = (pipeline
+                        | 'start' >> beam.Create(data)
+                        | 'add_timestamp' >> beam.Map(
+                            lambda v: beam.window.TimestampedValue(v, v[1]))
+                        | 'window' >> beam.WindowInto(Sessions(gap_size=2))
+                        | 'group_by_key' >> beam.GroupByKey())
+    assert_that(before_reshuffle, equal_to(expected_data),
+                label='before_reshuffle', reify_windows=True)
+    after_reshuffle = (before_reshuffle
+                       | 'reshuffle' >> beam.Reshuffle())
+    assert_that(after_reshuffle, equal_to(expected_data),
+                label='after reshuffle', reify_windows=True)
+    pipeline.run()
+
+  def test_reshuffle_window_fn_preserved(self):
+    pipeline = TestPipeline()
+    data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
+    expected_windows = [TestWindowedValue(v, t, [w]) for (v, t, w) in [
+        ((1, 1), 1.0, IntervalWindow(1.0, 3.0)),
+        ((2, 1), 1.0, IntervalWindow(1.0, 3.0)),
+        ((3, 1), 1.0, IntervalWindow(1.0, 3.0)),
+        ((1, 2), 2.0, IntervalWindow(2.0, 4.0)),
+        ((2, 2), 2.0, IntervalWindow(2.0, 4.0)),
+        ((1, 4), 4.0, IntervalWindow(4.0, 6.0))]]
+    expected_merged_windows = [TestWindowedValue(v, t, [w]) for (v, t, w) in [
+        ((1, contains_in_any_order([2, 1])), 4.0, IntervalWindow(1.0, 4.0)),
+        ((2, contains_in_any_order([2, 1])), 4.0, IntervalWindow(1.0, 4.0)),
+        ((3, [1]), 3.0, IntervalWindow(1.0, 3.0)),
+        ((1, [4]), 6.0, IntervalWindow(4.0, 6.0))]]
+    before_reshuffle = (pipeline
+                        | 'start' >> beam.Create(data)
+                        | 'add_timestamp' >> beam.Map(
+                            lambda v: TimestampedValue(v, v[1]))
+                        | 'window' >> beam.WindowInto(Sessions(gap_size=2)))
+    assert_that(before_reshuffle, equal_to(expected_windows),
+                label='before_reshuffle', reify_windows=True)
+    after_reshuffle = (before_reshuffle
+                       | 'reshuffle' >> beam.Reshuffle())
+    assert_that(after_reshuffle, equal_to(expected_windows),
+                label='after_reshuffle', reify_windows=True)
+    after_group = (after_reshuffle
+                   | 'group_by_key' >> beam.GroupByKey())
+    assert_that(after_group, equal_to(expected_merged_windows),
+                label='after_group', reify_windows=True)
+    pipeline.run()
+
+  def test_reshuffle_global_window(self):
+    pipeline = TestPipeline()
+    data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
+    expected_data = [(1, [1, 2, 4]), (2, [1, 2]), (3, [1])]
+    before_reshuffle = (pipeline
+                        | 'start' >> beam.Create(data)
+                        | 'window' >> beam.WindowInto(GlobalWindows())
+                        | 'group_by_key' >> beam.GroupByKey())
+    assert_that(before_reshuffle, equal_to(expected_data),
+                label='before_reshuffle')
+    after_reshuffle = (before_reshuffle
+                       | 'reshuffle' >> beam.Reshuffle())
+    assert_that(after_reshuffle, equal_to(expected_data),
+                label='after reshuffle')
+    pipeline.run()
+
+  def test_reshuffle_sliding_window(self):
+    pipeline = TestPipeline()
+    data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
+    window_size = 2
+    expected_data = [(1, [1, 2, 4]), (2, [1, 2]), (3, [1])] * window_size
+    before_reshuffle = (pipeline
+                        | 'start' >> beam.Create(data)
+                        | 'window' >> beam.WindowInto(SlidingWindows(
+                            size=window_size, period=1))
+                        | 'group_by_key' >> beam.GroupByKey())
+    assert_that(before_reshuffle, equal_to(expected_data),
+                label='before_reshuffle')
+    after_reshuffle = (before_reshuffle
+                       | 'reshuffle' >> beam.Reshuffle())
+    # If Reshuffle applies the sliding window function a second time there
+    # should be extra values for each key.
+    assert_that(after_reshuffle, equal_to(expected_data),
+                label='after reshuffle')
+    pipeline.run()
+
+  def test_reshuffle_streaming_global_window(self):
+    options = PipelineOptions()
+    options.view_as(StandardOptions).streaming = True
+    pipeline = TestPipeline(options=options)
+    data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
+    expected_data = [(1, [1, 2, 4]), (2, [1, 2]), (3, [1])]
+    before_reshuffle = (pipeline
+                        | 'start' >> beam.Create(data)
+                        | 'window' >> beam.WindowInto(GlobalWindows())
+                        | 'group_by_key' >> beam.GroupByKey())
+    assert_that(before_reshuffle, equal_to(expected_data),
+                label='before_reshuffle')
+    after_reshuffle = (before_reshuffle
+                       | 'reshuffle' >> beam.Reshuffle())
+    assert_that(after_reshuffle, equal_to(expected_data),
+                label='after reshuffle')
+    pipeline.run()
diff --git a/sdks/python/apache_beam/transforms/window.py b/sdks/python/apache_beam/transforms/window.py
index 8c8bf336bab..ee9d6f97187 100644
--- a/sdks/python/apache_beam/transforms/window.py
+++ b/sdks/python/apache_beam/transforms/window.py
@@ -114,13 +114,21 @@ class WindowFn(urns.RunnerApiFn):
   class AssignContext(object):
     """Context passed to WindowFn.assign()."""
 
-    def __init__(self, timestamp, element=None):
+    def __init__(self, timestamp, element=None, window=None):
       self.timestamp = Timestamp.of(timestamp)
       self.element = element
+      self.window = window
 
   @abc.abstractmethod
   def assign(self, assign_context):
-    """Associates a timestamp to an element."""
+    """Associates windows to an element.
+
+    Arguments:
+      assign_context: Instance of AssignContext.
+
+    Returns:
+      An iterable of BoundedWindow.
+    """
     raise NotImplementedError
 
   class MergeContext(object):
diff --git a/sdks/python/apache_beam/utils/urns.py b/sdks/python/apache_beam/utils/urns.py
index 1359f323bd9..387c8d6a337 100644
--- a/sdks/python/apache_beam/utils/urns.py
+++ b/sdks/python/apache_beam/utils/urns.py
@@ -44,6 +44,7 @@
 COMBINE_GROUPED_VALUES_TRANSFORM = "beam:ptransform:combine_grouped_values:v0.1"
 FLATTEN_TRANSFORM = "beam:ptransform:flatten:v0.1"
 READ_TRANSFORM = "beam:ptransform:read:v0.1"
+RESHUFFLE_TRANSFORM = "beam:ptransform:reshuffle:v0.1"
 WINDOW_INTO_TRANSFORM = "beam:ptransform:window_into:v0.1"
 
 PICKLED_SOURCE = "beam:source:pickled_python:v0.1"
@@ -90,9 +91,9 @@ def to_runner_api_parameter(self, unused_context):
 
   @classmethod
   def register_urn(cls, urn, parameter_type, fn=None):
-    """Registeres a urn with a constructor.
+    """Registers a urn with a constructor.
 
-    For example, if 'beam:fn:foo' had paramter type FooPayload, one could
+    For example, if 'beam:fn:foo' had parameter type FooPayload, one could
     write `RunnerApiFn.register_urn('bean:fn:foo', FooPayload, foo_from_proto)`
     where foo_from_proto took as arguments a FooPayload and a PipelineContext.
     This function can also be used as a decorator rather than passing the


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


> implement Reshuffle transform in python, make it experimental in Java
> ---------------------------------------------------------------------
>
>                 Key: BEAM-1872
>                 URL: https://issues.apache.org/jira/browse/BEAM-1872
>             Project: Beam
>          Issue Type: Improvement
>          Components: sdk-java-core, sdk-py-core
>            Reporter: Ahmet Altay
>            Assignee: Udi Meiri
>              Labels: sdk-consistency
>




--
This message was sent by Atlassian JIRA
(v6.4.14#64029)

Mime
View raw message