beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From chamik...@apache.org
Subject [1/2] beam git commit: Adds ability to dynamically replace PTransforms during runtime.
Date Tue, 13 Jun 2017 18:47:47 GMT
Repository: beam
Updated Branches:
  refs/heads/master c33e9b446 -> 7d0f24a0d


Adds ability to dynamically replace PTransforms during runtime.

To this end, adds the PTransformOverride interface, which provides a PTransform matcher (a callable that takes an AppliedPTransform) and the replacement PTransform.

Currently, only replacements whose input and output types exactly match those of the original
transform are supported (complexities due to type hints must be addressed before replacements
with different types can be supported).

This will be used for SplittableDoFn: matching ParDo transforms will be dynamically replaced
by SplittableParDo.


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/33662b92
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/33662b92
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/33662b92

Branch: refs/heads/master
Commit: 33662b92f1a9936f594223c9aee1a7233f59a569
Parents: c33e9b4
Author: chamikara@google.com <chamikara@google.com>
Authored: Thu Jun 8 14:56:24 2017 -0700
Committer: chamikara@google.com <chamikara@google.com>
Committed: Tue Jun 13 11:46:42 2017 -0700

----------------------------------------------------------------------
 sdks/python/apache_beam/pipeline.py             | 201 +++++++++++++++++++
 sdks/python/apache_beam/pipeline_test.py        |  35 ++++
 .../apache_beam/runners/direct/direct_runner.py |  11 +
 3 files changed, 247 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/33662b92/sdks/python/apache_beam/pipeline.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py
index cea7215..05715d7 100644
--- a/sdks/python/apache_beam/pipeline.py
+++ b/sdks/python/apache_beam/pipeline.py
@@ -45,6 +45,7 @@ Typical usage:
 
 from __future__ import absolute_import
 
+import abc
 import collections
 import logging
 import os
@@ -53,6 +54,7 @@ import tempfile
 
 from apache_beam import pvalue
 from apache_beam.internal import pickler
+from apache_beam.pvalue import PCollection
 from apache_beam.runners import create_runner
 from apache_beam.runners import PipelineRunner
 from apache_beam.transforms import ptransform
@@ -157,6 +159,157 @@ class Pipeline(object):
     """Returns the root transform of the transform stack."""
     return self.transforms_stack[0]
 
+  def _remove_labels_recursively(self, applied_transform):
+    """Removes all descendants' labels from self.applied_labels."""
+    for part in applied_transform.parts:
+      if part.full_label in self.applied_labels:
+        self.applied_labels.remove(part.full_label)
+      # Recurse on the part itself so that every level below it is covered.
+      self._remove_labels_recursively(part)
+
+  def _replace(self, override):
+    """Applies a single PTransformOverride to this pipeline in place."""
+    assert isinstance(override, PTransformOverride)
+    matcher = override.get_matcher()
+
+    output_map = {}
+    output_replacements = {}
+    input_replacements = {}
+
+    class TransformUpdater(PipelineVisitor): # pylint: disable=used-before-assignment
+      """A visitor that replaces the matching PTransforms."""
+
+      def __init__(self, pipeline):
+        self.pipeline = pipeline
+
+      def _replace_if_needed(self, transform_node):
+        if matcher(transform_node):
+          replacement_transform = override.get_replacement_transform(
+              transform_node.transform)
+          inputs = transform_node.inputs
+          # TODO: Support replacing PTransforms with zero or multiple inputs.
+          if len(inputs) != 1:
+            raise NotImplementedError(
+                'PTransform overriding is only supported for PTransforms that '
+                'have a single input. Tried to replace input of '
+                'AppliedPTransform %r that has %d inputs' %
+                (transform_node, len(inputs)))
+          transform_node.transform = replacement_transform
+          self.pipeline.transforms_stack.append(transform_node)
+
+          # Keep the label of the replaced node but drop its old children and,
+          # recursively, their labels: the expand below re-creates the subtree.
+          self.pipeline._remove_labels_recursively(transform_node)
+          transform_node.parts = []
+
+          new_output = replacement_transform.expand(inputs[0])
+
+          # We only support replacing transforms with a single output with
+          # another transform that produces a single output.
+          # TODO: Support replacing PTransforms with multiple outputs.
+          if (len(transform_node.outputs) > 1 or
+              not isinstance(transform_node.outputs.get(None), PCollection) or
+              not isinstance(new_output, PCollection)):
+            raise NotImplementedError(
+                'PTransform overriding is only supported for PTransforms that '
+                'have a single output. Tried to replace output of '
+                'AppliedPTransform %r with %r.' %
+                (transform_node, new_output))
+          if new_output.producer is None:
+            # When current transform is a primitive, we set the producer here.
+            new_output.producer = transform_node
+
+          # Recording updated outputs. This cannot be done in the same visitor
+          # since if we dynamically update output type here, we'll run into
+          # errors when visiting child nodes.
+          output_map[transform_node.outputs[None]] = new_output
+
+          self.pipeline.transforms_stack.pop()
+
+      def enter_composite_transform(self, transform_node):
+        self._replace_if_needed(transform_node)
+
+      def visit_transform(self, transform_node):
+        self._replace_if_needed(transform_node)
+
+    self.visit(TransformUpdater(self))
+
+    # Adjusting inputs and outputs
+    class InputOutputUpdater(PipelineVisitor): # pylint: disable=used-before-assignment
+      """A visitor that records input and output values to be replaced.
+
+      Input and output values that should be updated are recorded in maps
+      input_replacements and output_replacements respectively.
+
+      We cannot update input and output values while visiting since that results
+      in validation errors.
+      """
+
+      def __init__(self, pipeline):
+        self.pipeline = pipeline
+
+      def enter_composite_transform(self, transform_node):
+        self.visit_transform(transform_node)
+
+      def visit_transform(self, transform_node):
+        if (None in transform_node.outputs and
+            transform_node.outputs[None] in output_map):
+          output_replacements[transform_node] = (
+              output_map[transform_node.outputs[None]])
+
+        # Only nodes consuming at least one replaced PCollection need
+        # their inputs rewritten.
+        replace_input = any(
+            pcoll in output_map for pcoll in transform_node.inputs)
+
+        if replace_input:
+          # Built as a tuple so the rewritten inputs stay immutable.
+          new_input = tuple(
+              output_map.get(pcoll, pcoll)
+              for pcoll in transform_node.inputs)
+          input_replacements[transform_node] = new_input
+
+    self.visit(InputOutputUpdater(self))
+
+    for node, replacement_output in output_replacements.items():
+      node.replace_output(replacement_output)
+
+    for node, replacement_inputs in input_replacements.items():
+      node.inputs = replacement_inputs
+
+  def _check_replacement(self, override):
+    """Raises RuntimeError if any node still matches override's matcher."""
+    matcher = override.get_matcher()
+    class ReplacementValidator(PipelineVisitor):
+      def visit_transform(self, transform_node):
+        if matcher(transform_node):
+          raise RuntimeError('Transform node %r was not replaced as expected.'
+                             % (transform_node,))
+      enter_composite_transform = visit_transform  # Composites can match too.
+    self.visit(ReplacementValidator())
+
+  def replace_all(self, replacements):
+    """Dynamically replaces PTransforms in the currently populated hierarchy.
+
+    Currently this only works for replacements where input and output types
+    are exactly the same.
+    TODO: Update this to also work for transform overrides where input and
+    output types are different.
+
+    Args:
+      replacements: a list of PTransformOverride objects, applied in order.
+    """
+    for override in replacements:
+      assert isinstance(override, PTransformOverride)
+      self._replace(override)
+
+    # Checking if the PTransforms have been successfully replaced. This will
+    # result in a failure if a PTransform that was replaced in a given override
+    # gets re-added in a subsequent override. This is not allowed and ordering
+    # of PTransformOverride objects in 'replacements' is important.
+    for override in replacements:
+      self._check_replacement(override)
+
   def run(self, test_runner_api=True):
     """Runs the pipeline. Returns whatever our runner returns after running."""
 
@@ -441,6 +594,20 @@ class AppliedPTransform(object):
       for side_input in self.side_inputs:
         real_producer(side_input.pvalue).refcounts[side_input.pvalue.tag] += 1
 
+  def replace_output(self, output, tag=None):
+    """Replaces the output defined by the given tag with the given output.
+
+    Args:
+      output: replacement PValue; for a DoOutputsTuple its main output is used.
+      tag: tag of the output to be replaced; defaults to None.
+    """
+    if isinstance(output, pvalue.DoOutputsTuple):
+      self.replace_output(output[output._main_tag], tag)
+    elif isinstance(output, pvalue.PValue):
+      self.outputs[tag] = output
+    else:
+      raise TypeError("Unexpected output type: %s" % (output,))
+
   def add_output(self, output, tag=None):
     if isinstance(output, pvalue.DoOutputsTuple):
       self.add_output(output[output._main_tag])
@@ -564,3 +731,37 @@ class AppliedPTransform(object):
           pc.tag = tag
     result.update_input_refcounts()
     return result
+
+
+class PTransformOverride(object):
+  """For internal use only; no backwards-compatibility guarantees.
+
+  Pairs a matcher with a replacement for the PTransforms it matches.
+
+  TODO: Update this to support cases where input and/or output types are
+  different.
+  """
+  __metaclass__ = abc.ABCMeta  # NOTE: only honored by Python 2.
+
+  @abc.abstractmethod
+  def get_matcher(self):
+    """Gives a matcher that will be used to perform this override.
+
+    Returns:
+      a callable that takes an AppliedPTransform as a parameter and returns a
+      boolean as a result.
+    """
+    raise NotImplementedError
+
+  @abc.abstractmethod
+  def get_replacement_transform(self, ptransform):
+    """Provides a runner-specific replacement for a given PTransform.
+
+    Args:
+      ptransform: PTransform to be replaced.
+    Returns:
+      A PTransform that will be the replacement for the PTransform given as an
+      argument.
+    """
+    # Subclasses must override this and return a PTransform.
+    raise NotImplementedError

http://git-wip-us.apache.org/repos/asf/beam/blob/33662b92/sdks/python/apache_beam/pipeline_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/pipeline_test.py b/sdks/python/apache_beam/pipeline_test.py
index e0775d1..f9b894f 100644
--- a/sdks/python/apache_beam/pipeline_test.py
+++ b/sdks/python/apache_beam/pipeline_test.py
@@ -28,9 +28,11 @@ import apache_beam as beam
 from apache_beam.io import Read
 from apache_beam.metrics import Metrics
 from apache_beam.pipeline import Pipeline
+from apache_beam.pipeline import PTransformOverride
 from apache_beam.pipeline import PipelineOptions
 from apache_beam.pipeline import PipelineVisitor
 from apache_beam.pvalue import AsSingleton
+from apache_beam.runners import DirectRunner
 from apache_beam.runners.dataflow.native_io.iobase import NativeSource
 from apache_beam.testing.test_pipeline import TestPipeline
 from apache_beam.testing.util import assert_that
@@ -75,6 +77,18 @@ class FakeSource(NativeSource):
     return FakeSource._Reader(self._vals)
 
 
+class DoubleParDo(beam.PTransform):  # Maps each element a to a * 2.
+  def expand(self, input):
+    return input | 'Inner' >> beam.Map(lambda a: a * 2)
+
+
+class TripleParDo(beam.PTransform):  # Maps each element a to a * 3.
+  def expand(self, input):
+    # Intentionally reuses DoubleParDo's 'Inner' label to check that replacing
+    # DoubleParDo with TripleParDo does not cause a label conflict.
+    return input | 'Inner' >> beam.Map(lambda a: a * 3)
+
+
 class PipelineTest(unittest.TestCase):
 
   @staticmethod
@@ -285,6 +299,27 @@ class PipelineTest(unittest.TestCase):
   #   p = Pipeline('EagerRunner')
   #   self.assertEqual([1, 4, 9], p | Create([1, 2, 3]) | Map(lambda x: x*x))
 
+  def test_ptransform_overrides(self):
+    def my_par_do_matcher(applied_ptransform):
+      return isinstance(applied_ptransform.transform, DoubleParDo)
+
+    class MyParDoOverride(PTransformOverride):
+      def get_matcher(self):
+        return my_par_do_matcher
+
+      def get_replacement_transform(self, ptransform):
+        if isinstance(ptransform, DoubleParDo):
+          return TripleParDo()
+        raise ValueError('Unsupported type of transform: %r' % ptransform)
+    # Private class-level state; undone in 'finally' so no other test sees it.
+    DirectRunner._PTRANSFORM_OVERRIDES.append(MyParDoOverride())
+    try:
+      with Pipeline() as p:
+        pcoll = p | beam.Create([1, 2, 3]) | 'Multiply' >> DoubleParDo()
+        assert_that(pcoll, equal_to([3, 6, 9]))
+    finally:
+      DirectRunner._PTRANSFORM_OVERRIDES.pop()
+
 
 class DoFnTest(unittest.TestCase):
 

http://git-wip-us.apache.org/repos/asf/beam/blob/33662b92/sdks/python/apache_beam/runners/direct/direct_runner.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py
index ecf5114..323f44b 100644
--- a/sdks/python/apache_beam/runners/direct/direct_runner.py
+++ b/sdks/python/apache_beam/runners/direct/direct_runner.py
@@ -42,6 +42,14 @@ __all__ = ['DirectRunner']
 class DirectRunner(PipelineRunner):
   """Executes a single pipeline on the local machine."""
 
+  # PTransformOverride objects applied, in order, via Pipeline.replace_all()
+  # at the beginning of DirectRunner.run().
+  # Currently this only works for overrides where the input and output types do
+  # not change. Note that this list is shared class-level state: anything
+  # appended here affects every subsequent run in the process.
+  # For internal SDK use only; Beam pipeline authors should not modify it.
+  _PTRANSFORM_OVERRIDES = []
+
   def __init__(self):
     self._cache = None
 
@@ -59,6 +67,9 @@ class DirectRunner(PipelineRunner):
   def run(self, pipeline):
     """Execute the entire pipeline and returns an DirectPipelineResult."""
 
+    # Performing configured PTransform overrides.
+    pipeline.replace_all(DirectRunner._PTRANSFORM_OVERRIDES)
+
     # TODO: Move imports to top. Pipeline <-> Runner dependency cause problems
     # with resolving imports when they are at top.
     # pylint: disable=wrong-import-position


Mime
View raw message