beam-commits mailing list archives

From al...@apache.org
Subject [1/2] beam git commit: Migrate DirectRunner evaluators to use Beam state API
Date Thu, 15 Jun 2017 17:07:49 GMT
Repository: beam
Updated Branches:
  refs/heads/master 42c88f415 -> 6fc70b35f


Migrate DirectRunner evaluators to use Beam state API


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/be09a162
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/be09a162
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/be09a162

Branch: refs/heads/master
Commit: be09a162e32d158f5ae043e064223bb4f3742648
Parents: 42c88f4
Author: Charles Chen <ccy@google.com>
Authored: Wed Jun 14 16:14:50 2017 -0700
Committer: Charles Chen <ccy@google.com>
Committed: Thu Jun 15 01:39:08 2017 -0700

----------------------------------------------------------------------
 .../runners/dataflow/native_io/iobase_test.py   | 39 ++++++++++-
 .../runners/direct/evaluation_context.py        | 56 +++++++++++----
 .../runners/direct/transform_evaluator.py       | 74 +++++++++++---------
 .../runners/direct/transform_result.py          |  3 +-
 4 files changed, 122 insertions(+), 50 deletions(-)
----------------------------------------------------------------------
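For readers new to the Beam state primitives used here: instead of threading evaluator-specific state objects through TransformResult, the evaluators now read and write state through tags on an InMemoryUnmergedState. A minimal sketch of that pattern, using only names that appear in the diffs below (the window argument of None and the values are illustrative only):

    from apache_beam.transforms.trigger import InMemoryUnmergedState
    from apache_beam.transforms.trigger import _ListStateTag

    ELEMENTS_TAG = _ListStateTag('elements')

    # Accumulate values under a list tag, then read them back; the DirectRunner
    # subclasses this as DirectUnmergedState (see evaluation_context.py below).
    state = InMemoryUnmergedState(defensive_copy=False)
    state.add_state(None, ELEMENTS_TAG, 'a')
    state.add_state(None, ELEMENTS_TAG, 'b')
    assert list(state.get_state(None, ELEMENTS_TAG)) == ['a', 'b']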


http://git-wip-us.apache.org/repos/asf/beam/blob/be09a162/sdks/python/apache_beam/runners/dataflow/native_io/iobase_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/dataflow/native_io/iobase_test.py b/sdks/python/apache_beam/runners/dataflow/native_io/iobase_test.py
index 7610baf..3d8c24f 100644
--- a/sdks/python/apache_beam/runners/dataflow/native_io/iobase_test.py
+++ b/sdks/python/apache_beam/runners/dataflow/native_io/iobase_test.py
@@ -20,7 +20,9 @@
 
 import unittest
 
-from apache_beam import error, pvalue
+from apache_beam import Create
+from apache_beam import error
+from apache_beam import pvalue
 from apache_beam.runners.dataflow.native_io.iobase import (
     _dict_printable_fields,
     _NativeWrite,
@@ -28,10 +30,12 @@ from apache_beam.runners.dataflow.native_io.iobase import (
     DynamicSplitRequest,
     DynamicSplitResultWithPosition,
     NativeSink,
+    NativeSinkWriter,
     NativeSource,
     ReaderPosition,
     ReaderProgress
 )
+from apache_beam.testing.test_pipeline import TestPipeline
 
 
 class TestHelperFunctions(unittest.TestCase):
@@ -154,6 +158,39 @@ class TestNativeSink(unittest.TestCase):
     fake_sink = FakeSink()
     self.assertEqual(fake_sink.__repr__(), "<FakeSink ['validate=False']>")
 
+  def test_on_direct_runner(self):
+    class FakeSink(NativeSink):
+      """A fake sink outputting a number of elements."""
+
+      def __init__(self):
+        self.written_values = []
+        self.writer_instance = FakeSinkWriter(self.written_values)
+
+      def writer(self):
+        return self.writer_instance
+
+    class FakeSinkWriter(NativeSinkWriter):
+      """A fake sink writer for testing."""
+
+      def __init__(self, written_values):
+        self.written_values = written_values
+
+      def __enter__(self):
+        return self
+
+      def __exit__(self, *unused_args):
+        pass
+
+      def Write(self, value):
+        self.written_values.append(value)
+
+    p = TestPipeline()
+    sink = FakeSink()
+    p | Create(['a', 'b', 'c']) | _NativeWrite(sink)  # pylint: disable=expression-not-assigned
+    p.run()
+
+    self.assertEqual(['a', 'b', 'c'], sink.written_values)
+
 
 class Test_NativeWrite(unittest.TestCase):
 

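The new test above relies on the NativeSinkWriter context-manager protocol that _NativeWriteEvaluator drives: the evaluator opens the sink's writer and calls Write per element. A hedged sketch of that interaction (the loop below is illustrative; the runner's actual write loop lives in _NativeWriteEvaluator.finish_bundle in the transform_evaluator.py diff):

    # Illustrative only: drive a NativeSink's writer the way the evaluator does.
    sink = FakeSink()              # the FakeSink defined in the test above
    with sink.writer() as writer:  # NativeSinkWriter is used as a context manager
        for value in ['a', 'b', 'c']:
            writer.Write(value)
    assert sink.written_values == ['a', 'b', 'c']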
http://git-wip-us.apache.org/repos/asf/beam/blob/be09a162/sdks/python/apache_beam/runners/direct/evaluation_context.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/direct/evaluation_context.py b/sdks/python/apache_beam/runners/direct/evaluation_context.py
index 68d99d3..8fa8e06 100644
--- a/sdks/python/apache_beam/runners/direct/evaluation_context.py
+++ b/sdks/python/apache_beam/runners/direct/evaluation_context.py
@@ -27,22 +27,22 @@ from apache_beam.runners.direct.clock import Clock
 from apache_beam.runners.direct.watermark_manager import WatermarkManager
 from apache_beam.runners.direct.executor import TransformExecutor
 from apache_beam.runners.direct.direct_metrics import DirectMetrics
+from apache_beam.transforms.trigger import InMemoryUnmergedState
 from apache_beam.utils import counters
 
 
 class _ExecutionContext(object):
 
-  def __init__(self, watermarks, existing_state):
-    self._watermarks = watermarks
-    self._existing_state = existing_state
+  def __init__(self, watermarks, keyed_states):
+    self.watermarks = watermarks
+    self.keyed_states = keyed_states
 
-  @property
-  def watermarks(self):
-    return self._watermarks
+    self._step_context = None
 
-  @property
-  def existing_state(self):
-    return self._existing_state
+  def get_step_context(self):
+    if not self._step_context:
+      self._step_context = DirectStepContext(self.keyed_states)
+    return self._step_context
 
 
 class _SideInputView(object):
@@ -145,9 +145,8 @@ class EvaluationContext(object):
     self._pcollection_to_views = collections.defaultdict(list)
     for view in views:
       self._pcollection_to_views[view.pvalue].append(view)
-
-    # AppliedPTransform -> Evaluator specific state objects
-    self._application_state_interals = {}
+    self._transform_keyed_states = self._initialize_keyed_states(
+        root_transforms, value_to_consumers)
     self._watermark_manager = WatermarkManager(
         Clock(), root_transforms, value_to_consumers)
     self._side_inputs_container = _SideInputsContainer(views)
@@ -158,6 +157,15 @@ class EvaluationContext(object):
 
     self._lock = threading.Lock()
 
+  def _initialize_keyed_states(self, root_transforms, value_to_consumers):
+    transform_keyed_states = {}
+    for transform in root_transforms:
+      transform_keyed_states[transform] = {}
+    for consumers in value_to_consumers.values():
+      for consumer in consumers:
+        transform_keyed_states[consumer] = {}
+    return transform_keyed_states
+
   def use_pvalue_cache(self, cache):
     assert not self._cache
     self._cache = cache
@@ -231,7 +239,6 @@ class EvaluationContext(object):
               counter.name, counter.combine_fn)
           merged_counter.accumulator.merge([counter.accumulator])
 
-      self._application_state_interals[result.transform] = result.state
       return committed_bundles
 
   def get_aggregator_values(self, aggregator_or_name):
@@ -256,7 +263,7 @@ class EvaluationContext(object):
   def get_execution_context(self, applied_ptransform):
     return _ExecutionContext(
         self._watermark_manager.get_watermarks(applied_ptransform),
-        self._application_state_interals.get(applied_ptransform))
+        self._transform_keyed_states[applied_ptransform])
 
   def create_bundle(self, output_pcollection):
     """Create an uncommitted bundle for the specified PCollection."""
@@ -296,3 +303,24 @@ class EvaluationContext(object):
     assert isinstance(task, TransformExecutor)
     return self._side_inputs_container.get_value_or_schedule_after_output(
         side_input, task)
+
+
+class DirectUnmergedState(InMemoryUnmergedState):
+  """UnmergedState implementation for the DirectRunner."""
+
+  def __init__(self):
+    super(DirectUnmergedState, self).__init__(defensive_copy=False)
+
+
+class DirectStepContext(object):
+  """Context for the currently-executing step."""
+
+  def __init__(self, keyed_existing_state):
+    self.keyed_existing_state = keyed_existing_state
+
+  def get_keyed_state(self, key):
+    # TODO(ccy): consider implementing transactional copy on write semantics
+    # for state so that work items can be safely retried.
+    if not self.keyed_existing_state.get(key):
+      self.keyed_existing_state[key] = DirectUnmergedState()
+    return self.keyed_existing_state[key]

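A short usage sketch for the step context added above: keyed state is created lazily per key, and the key None is reserved by the evaluators for global (non-keyed) state. The tag name and values here are illustrative only:

    from apache_beam.runners.direct.evaluation_context import DirectStepContext
    from apache_beam.transforms.trigger import _ListStateTag

    ELEMENTS_TAG = _ListStateTag('elements')

    step_context = DirectStepContext({})            # backing dict: key -> DirectUnmergedState
    state = step_context.get_keyed_state('key-1')   # lazily creates a DirectUnmergedState
    state.add_state(None, ELEMENTS_TAG, 42)
    assert list(step_context.get_keyed_state('key-1')
                .get_state(None, ELEMENTS_TAG)) == [42]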
http://git-wip-us.apache.org/repos/asf/beam/blob/be09a162/sdks/python/apache_beam/runners/direct/transform_evaluator.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/direct/transform_evaluator.py b/sdks/python/apache_beam/runners/direct/transform_evaluator.py
index b1cb626..f5b5db5 100644
--- a/sdks/python/apache_beam/runners/direct/transform_evaluator.py
+++ b/sdks/python/apache_beam/runners/direct/transform_evaluator.py
@@ -33,6 +33,8 @@ from apache_beam.runners.dataflow.native_io.iobase import _NativeWrite  # pylint
 from apache_beam.transforms import core
 from apache_beam.transforms.window import GlobalWindows
 from apache_beam.transforms.window import WindowedValue
+from apache_beam.transforms.trigger import _CombiningValueStateTag
+from apache_beam.transforms.trigger import _ListStateTag
 from apache_beam.typehints.typecheck import OutputCheckWrapperDoFn
 from apache_beam.typehints.typecheck import TypeCheckError
 from apache_beam.typehints.typecheck import TypeCheckWrapperDoFn
@@ -207,7 +209,7 @@ class _BoundedReadEvaluator(_TransformEvaluator):
         bundles = _read_values_to_bundles(reader)
 
     return TransformResult(
-        self._applied_ptransform, bundles, None, None, None, None)
+        self._applied_ptransform, bundles, None, None, None)
 
 
 class _FlattenEvaluator(_TransformEvaluator):
@@ -231,7 +233,7 @@ class _FlattenEvaluator(_TransformEvaluator):
   def finish_bundle(self):
     bundles = [self.bundle]
     return TransformResult(
-        self._applied_ptransform, bundles, None, None, None, None)
+        self._applied_ptransform, bundles, None, None, None)
 
 
 class _TaggedReceivers(dict):
@@ -320,7 +322,7 @@ class _ParDoEvaluator(_TransformEvaluator):
     bundles = self._tagged_receivers.values()
     result_counters = self._counter_factory.get_counters()
     return TransformResult(
-        self._applied_ptransform, bundles, None, None, result_counters, None,
+        self._applied_ptransform, bundles, None, result_counters, None,
         self._tagged_receivers.undeclared_in_memory_tag_values)
 
 
@@ -328,13 +330,8 @@ class _GroupByKeyOnlyEvaluator(_TransformEvaluator):
   """TransformEvaluator for _GroupByKeyOnly transform."""
 
   MAX_ELEMENT_PER_BUNDLE = None
-
-  class _GroupByKeyOnlyEvaluatorState(object):
-
-    def __init__(self):
-      # output: {} key -> [values]
-      self.output = collections.defaultdict(list)
-      self.completed = False
+  ELEMENTS_TAG = _ListStateTag('elements')
+  COMPLETION_TAG = _CombiningValueStateTag('completed', any)
 
   def __init__(self, evaluation_context, applied_ptransform,
                input_committed_bundle, side_inputs, scoped_metrics_container):
@@ -349,9 +346,8 @@ class _GroupByKeyOnlyEvaluator(_TransformEvaluator):
             == WatermarkManager.WATERMARK_POS_INF)
 
   def start_bundle(self):
-    self.state = (self._execution_context.existing_state
-                  if self._execution_context.existing_state
-                  else _GroupByKeyOnlyEvaluator._GroupByKeyOnlyEvaluatorState())
+    self.step_context = self._execution_context.get_step_context()
+    self.global_state = self.step_context.get_keyed_state(None)
 
     assert len(self._outputs) == 1
     self.output_pcollection = list(self._outputs)[0]
@@ -362,12 +358,15 @@ class _GroupByKeyOnlyEvaluator(_TransformEvaluator):
     self.key_coder = coders.registry.get_coder(kv_type_hint[0].tuple_types[0])
 
   def process_element(self, element):
-    assert not self.state.completed
+    assert not self.global_state.get_state(
+        None, _GroupByKeyOnlyEvaluator.COMPLETION_TAG)
     if (isinstance(element, WindowedValue)
         and isinstance(element.value, collections.Iterable)
         and len(element.value) == 2):
       k, v = element.value
-      self.state.output[self.key_coder.encode(k)].append(v)
+      encoded_k = self.key_coder.encode(k)
+      state = self.step_context.get_keyed_state(encoded_k)
+      state.add_state(None, _GroupByKeyOnlyEvaluator.ELEMENTS_TAG, v)
     else:
       raise TypeCheckError('Input to _GroupByKeyOnly must be a PCollection of '
                            'windowed key-value pairs. Instead received: %r.'
@@ -375,15 +374,23 @@ class _GroupByKeyOnlyEvaluator(_TransformEvaluator):
 
   def finish_bundle(self):
     if self._is_final_bundle:
-      if self.state.completed:
+      if self.global_state.get_state(
+          None, _GroupByKeyOnlyEvaluator.COMPLETION_TAG):
         # Ignore empty bundles after emitting output. (This may happen because
         # empty bundles do not affect input watermarks.)
         bundles = []
       else:
-        gbk_result = (
-            map(GlobalWindows.windowed_value, (
-                (self.key_coder.decode(k), v)
-                for k, v in self.state.output.iteritems())))
+        gbk_result = []
+        # TODO(ccy): perhaps we can clean this up to not use this
+        # internal attribute of the DirectStepContext.
+        for encoded_k in self.step_context.keyed_existing_state:
+          # Ignore global state.
+          if encoded_k is None:
+            continue
+          k = self.key_coder.decode(encoded_k)
+          state = self.step_context.get_keyed_state(encoded_k)
+          vs = state.get_state(None, _GroupByKeyOnlyEvaluator.ELEMENTS_TAG)
+          gbk_result.append(GlobalWindows.windowed_value((k, vs)))
 
         def len_element_fn(element):
           _, v = element.value
@@ -393,21 +400,22 @@ class _GroupByKeyOnlyEvaluator(_TransformEvaluator):
             self.output_pcollection, gbk_result,
             _GroupByKeyOnlyEvaluator.MAX_ELEMENT_PER_BUNDLE, len_element_fn)
 
-      self.state.completed = True
-      state = self.state
+      self.global_state.add_state(
+          None, _GroupByKeyOnlyEvaluator.COMPLETION_TAG, True)
       hold = WatermarkManager.WATERMARK_POS_INF
     else:
       bundles = []
-      state = self.state
       hold = WatermarkManager.WATERMARK_NEG_INF
 
     return TransformResult(
-        self._applied_ptransform, bundles, state, None, None, hold)
+        self._applied_ptransform, bundles, None, None, hold)
 
 
 class _NativeWriteEvaluator(_TransformEvaluator):
   """TransformEvaluator for _NativeWrite transform."""
 
+  ELEMENTS_TAG = _ListStateTag('elements')
+
   def __init__(self, evaluation_context, applied_ptransform,
                input_committed_bundle, side_inputs, scoped_metrics_container):
     assert not side_inputs
@@ -429,12 +437,12 @@ class _NativeWriteEvaluator(_TransformEvaluator):
             == WatermarkManager.WATERMARK_POS_INF)
 
   def start_bundle(self):
-    # state: [values]
-    self.state = (self._execution_context.existing_state
-                  if self._execution_context.existing_state else [])
+    self.step_context = self._execution_context.get_step_context()
+    self.global_state = self.step_context.get_keyed_state(None)
 
   def process_element(self, element):
-    self.state.append(element)
+    self.global_state.add_state(
+        None, _NativeWriteEvaluator.ELEMENTS_TAG, element)
 
   def finish_bundle(self):
     # finish_bundle will append incoming bundles in memory until all the bundles
@@ -444,19 +452,19 @@ class _NativeWriteEvaluator(_TransformEvaluator):
     # ignored and would not generate additional output files.
     # TODO(altay): Do not wait until the last bundle to write in a single shard.
     if self._is_final_bundle:
+      elements = self.global_state.get_state(
+          None, _NativeWriteEvaluator.ELEMENTS_TAG)
       if self._has_already_produced_output:
         # Ignore empty bundles that arrive after the output is produced.
-        assert self.state == []
+        assert elements == []
       else:
         self._sink.pipeline_options = self._evaluation_context.pipeline_options
         with self._sink.writer() as writer:
-          for v in self.state:
+          for v in elements:
             writer.Write(v.value)
-      state = None
       hold = WatermarkManager.WATERMARK_POS_INF
     else:
-      state = self.state
       hold = WatermarkManager.WATERMARK_NEG_INF
 
     return TransformResult(
-        self._applied_ptransform, [], state, None, None, hold)
+        self._applied_ptransform, [], None, None, hold)

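To illustrate the completion flag used by _GroupByKeyOnlyEvaluator above: _CombiningValueStateTag('completed', any) combines everything added under the tag with the builtin any, so the flag reads as falsy until True has been added once. A hedged sketch, using a standalone InMemoryUnmergedState instead of the evaluator's global keyed state:

    from apache_beam.transforms.trigger import InMemoryUnmergedState
    from apache_beam.transforms.trigger import _CombiningValueStateTag

    COMPLETION_TAG = _CombiningValueStateTag('completed', any)

    state = InMemoryUnmergedState(defensive_copy=False)
    assert not state.get_state(None, COMPLETION_TAG)  # no output emitted yet
    state.add_state(None, COMPLETION_TAG, True)
    assert state.get_state(None, COMPLETION_TAG)      # later empty bundles see the flag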
http://git-wip-us.apache.org/repos/asf/beam/blob/be09a162/sdks/python/apache_beam/runners/direct/transform_result.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/direct/transform_result.py b/sdks/python/apache_beam/runners/direct/transform_result.py
index febdd20..51593e3 100644
--- a/sdks/python/apache_beam/runners/direct/transform_result.py
+++ b/sdks/python/apache_beam/runners/direct/transform_result.py
@@ -25,12 +25,11 @@ class TransformResult(object):
 
   The result of evaluating an AppliedPTransform with a TransformEvaluator."""
 
-  def __init__(self, applied_ptransform, uncommitted_output_bundles, state,
+  def __init__(self, applied_ptransform, uncommitted_output_bundles,
                timer_update, counters, watermark_hold,
                undeclared_tag_values=None):
     self.transform = applied_ptransform
     self.uncommitted_output_bundles = uncommitted_output_bundles
-    self.state = state
     # TODO: timer update is currently unused.
     self.timer_update = timer_update
     self.counters = counters

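Following the constructor change above, the evaluators now build results without a state argument. A hedged sketch of the new call shape (the None values and empty list are placeholders standing in for a real AppliedPTransform, bundles, and counters):

    from apache_beam.runners.direct.transform_result import TransformResult
    from apache_beam.runners.direct.watermark_manager import WatermarkManager

    result = TransformResult(
        None,                                # applied_ptransform (placeholder)
        [],                                  # uncommitted_output_bundles
        None,                                # timer_update (currently unused)
        None,                                # counters
        WatermarkManager.WATERMARK_POS_INF)  # watermark_hold
    assert result.uncommitted_output_bundles == []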
