beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From rober...@apache.org
Subject [1/2] beam git commit: Cleanup and fix ptransform_fn decorator.
Date Wed, 12 Jul 2017 01:08:30 GMT
Repository: beam
Updated Branches:
  refs/heads/master 84682109b -> 91c7d3d1f


Cleanup and fix ptransform_fn decorator.

Previously, CallablePTransform was used both as the factory and as
the transform itself, which could result in state being carried
over between pipelines.


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/2b86a61e
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/2b86a61e
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/2b86a61e

Branch: refs/heads/master
Commit: 2b86a61e5bb07d3bd7f958e124bc8d79dc300c3f
Parents: 8468210
Author: Robert Bradshaw <robertwb@gmail.com>
Authored: Tue Jul 11 14:32:47 2017 -0700
Committer: Robert Bradshaw <robertwb@gmail.com>
Committed: Tue Jul 11 18:08:01 2017 -0700

----------------------------------------------------------------------
 sdks/python/apache_beam/transforms/combiners.py |  8 ++++
 .../apache_beam/transforms/combiners_test.py    |  7 +---
 .../python/apache_beam/transforms/ptransform.py | 41 +++++++++-----------
 3 files changed, 28 insertions(+), 28 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/2b86a61e/sdks/python/apache_beam/transforms/combiners.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/transforms/combiners.py b/sdks/python/apache_beam/transforms/combiners.py
index fa0742d..875306f 100644
--- a/sdks/python/apache_beam/transforms/combiners.py
+++ b/sdks/python/apache_beam/transforms/combiners.py
@@ -149,6 +149,7 @@ class Top(object):
   """Combiners for obtaining extremal elements."""
   # pylint: disable=no-self-argument
 
+  @staticmethod
   @ptransform.ptransform_fn
   def Of(pcoll, n, compare=None, *args, **kwargs):
     """Obtain a list of the compare-most N elements in a PCollection.
@@ -177,6 +178,7 @@ class Top(object):
     return pcoll | core.CombineGlobally(
         TopCombineFn(n, compare, key, reverse), *args, **kwargs)
 
+  @staticmethod
   @ptransform.ptransform_fn
   def PerKey(pcoll, n, compare=None, *args, **kwargs):
     """Identifies the compare-most N elements associated with each key.
@@ -210,21 +212,25 @@ class Top(object):
     return pcoll | core.CombinePerKey(
         TopCombineFn(n, compare, key, reverse), *args, **kwargs)
 
+  @staticmethod
   @ptransform.ptransform_fn
   def Largest(pcoll, n):
     """Obtain a list of the greatest N elements in a PCollection."""
     return pcoll | Top.Of(n)
 
+  @staticmethod
   @ptransform.ptransform_fn
   def Smallest(pcoll, n):
     """Obtain a list of the least N elements in a PCollection."""
     return pcoll | Top.Of(n, reverse=True)
 
+  @staticmethod
   @ptransform.ptransform_fn
   def LargestPerKey(pcoll, n):
     """Identifies the N greatest elements associated with each key."""
     return pcoll | Top.PerKey(n)
 
+  @staticmethod
   @ptransform.ptransform_fn
   def SmallestPerKey(pcoll, n, reverse=True):
     """Identifies the N least elements associated with each key."""
@@ -369,10 +375,12 @@ class Sample(object):
   """Combiners for sampling n elements without replacement."""
   # pylint: disable=no-self-argument
 
+  @staticmethod
   @ptransform.ptransform_fn
   def FixedSizeGlobally(pcoll, n):
     return pcoll | core.CombineGlobally(SampleCombineFn(n))
 
+  @staticmethod
   @ptransform.ptransform_fn
   def FixedSizePerKey(pcoll, n):
     return pcoll | core.CombinePerKey(SampleCombineFn(n))

http://git-wip-us.apache.org/repos/asf/beam/blob/2b86a61e/sdks/python/apache_beam/transforms/combiners_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/transforms/combiners_test.py b/sdks/python/apache_beam/transforms/combiners_test.py
index c79fec8..cd2b595 100644
--- a/sdks/python/apache_beam/transforms/combiners_test.py
+++ b/sdks/python/apache_beam/transforms/combiners_test.py
@@ -156,14 +156,11 @@ class CombineTest(unittest.TestCase):
 
   def test_combine_sample_display_data(self):
     def individual_test_per_key_dd(sampleFn, args, kwargs):
-      trs = [beam.CombinePerKey(sampleFn(*args, **kwargs)),
-             beam.CombineGlobally(sampleFn(*args, **kwargs))]
+      trs = [sampleFn(*args, **kwargs)]
       for transform in trs:
         dd = DisplayData.create_from(transform)
         expected_items = [
-            DisplayDataItemMatcher('fn', sampleFn.fn.__name__),
-            DisplayDataItemMatcher('combine_fn',
-                                   transform.fn.__class__)]
+            DisplayDataItemMatcher('fn', transform._fn.__name__)]
         if args:
           expected_items.append(
               DisplayDataItemMatcher('args', str(args)))

http://git-wip-us.apache.org/repos/asf/beam/blob/2b86a61e/sdks/python/apache_beam/transforms/ptransform.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py
index 6041353..cd84122 100644
--- a/sdks/python/apache_beam/transforms/ptransform.py
+++ b/sdks/python/apache_beam/transforms/ptransform.py
@@ -595,32 +595,23 @@ class PTransformWithSideInputs(PTransform):
     return '%s(%s)' % (self.__class__.__name__, self.fn.default_label())
 
 
-class CallablePTransform(PTransform):
+class _PTransformFnPTransform(PTransform):
   """A class wrapper for a function-based transform."""
 
-  def __init__(self, fn):
-    # pylint: disable=super-init-not-called
-    # This  is a helper class for a function decorator. Only when the class
-    # is called (and __call__ invoked) we will have all the information
-    # needed to initialize the super class.
-    self.fn = fn
-    self._args = ()
-    self._kwargs = {}
+  def __init__(self, fn, *args, **kwargs):
+    super(_PTransformFnPTransform, self).__init__()
+    self._fn = fn
+    self._args = args
+    self._kwargs = kwargs
 
   def display_data(self):
-    res = {'fn': (self.fn.__name__
-                  if hasattr(self.fn, '__name__')
-                  else self.fn.__class__),
+    res = {'fn': (self._fn.__name__
+                  if hasattr(self._fn, '__name__')
+                  else self._fn.__class__),
            'args': DisplayDataItem(str(self._args)).drop_if_default('()'),
            'kwargs': DisplayDataItem(str(self._kwargs)).drop_if_default('{}')}
     return res
 
-  def __call__(self, *args, **kwargs):
-    super(CallablePTransform, self).__init__()
-    self._args = args
-    self._kwargs = kwargs
-    return self
-
   def expand(self, pcoll):
     # Since the PTransform will be implemented entirely as a function
     # (once called), we need to pass through any type-hinting information that
@@ -629,18 +620,18 @@ class CallablePTransform(PTransform):
     kwargs = dict(self._kwargs)
     args = tuple(self._args)
     try:
-      if 'type_hints' in inspect.getargspec(self.fn).args:
+      if 'type_hints' in inspect.getargspec(self._fn).args:
         args = (self.get_type_hints(),) + args
     except TypeError:
       # Might not be a function.
       pass
-    return self.fn(pcoll, *args, **kwargs)
+    return self._fn(pcoll, *args, **kwargs)
 
   def default_label(self):
     if self._args:
       return '%s(%s)' % (
-          label_from_callable(self.fn), label_from_callable(self._args[0]))
-    return label_from_callable(self.fn)
+          label_from_callable(self._fn), label_from_callable(self._args[0]))
+    return label_from_callable(self._fn)
 
 
 def ptransform_fn(fn):
@@ -684,7 +675,11 @@ def ptransform_fn(fn):
   operator (i.e., `|`) will inject the pcoll argument in its proper place
   (first argument if no label was specified and second argument otherwise).
   """
-  return CallablePTransform(fn)
+  # TODO(robertwb): Consider removing staticmethod to allow for self parameter.
+
+  def callable_ptransform_factory(*args, **kwargs):
+    return _PTransformFnPTransform(fn, *args, **kwargs)
+  return callable_ptransform_factory
 
 
 def label_from_callable(fn):


Mime
View raw message