beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From pabl...@apache.org
Subject [beam] branch master updated: Merge pull request #14869 from [BEAM-12357] improve WithKeys transform to take args, kwargs
Date Tue, 13 Jul 2021 04:56:57 GMT
This is an automated email from the ASF dual-hosted git repository.

pabloem pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new f9b9ccc  Merge pull request #14869 from [BEAM-12357] improve WithKeys transform to take args, kwargs
f9b9ccc is described below

commit f9b9ccc64bc44867be46b65af0e35e8287e19be4
Author: heidimhurst <32763215+heidimhurst@users.noreply.github.com>
AuthorDate: Mon Jul 12 22:56:16 2021 -0600

    Merge pull request #14869 from [BEAM-12357] improve WithKeys transform to take args, kwargs
    
    * [BEAM-12357] improve WithKeys transform to take args, kwargs
    
    This commit extends the existing functionality of WithKeys to accept
    positional and keyword arguments, consistent with the use of the Map
    function.  This allows inputs to be passed into the key creation
    function. Example use: util.WithKeys(key_fn, foo, kwarg1=bar) would
    pass variables foo and bar into key_fn when each key is created.
    
    * [BEAM-12357] PR feedback: move fn_takes_side_inputs to utils
    
    * Update sdks/python/apache_beam/transforms/util.py
    
    update WithKeys to pass in *args, **kwargs to internal Map function, not just lambda
    
    Co-authored-by: Pablo <pabloem@users.noreply.github.com>
    
    * Revert "Update sdks/python/apache_beam/transforms/util.py"
    
    This reverts commit a5a654860684f2978140f63054f1f391163f4b7c.
    
    * Preventing circular import in core.py
    
    * allow WithKeys to take side inputs
    
    Expand handling of args, kwargs within WithKeys PTransform to
    include side inputs as well as non-pcollection inputs. This
    includes the following changes:
    - additional if case in WithKeys checking for AsSideInput
    - additional test case for AsSideInput inputs
    - adds AsSideInput as visible class in pvalue.py
    
    * fix lint
    
    Co-authored-by: Pablo <pabloem@users.noreply.github.com>
---
 sdks/python/apache_beam/pvalue.py               |  1 +
 sdks/python/apache_beam/transforms/core.py      | 16 ++-----------
 sdks/python/apache_beam/transforms/util.py      | 32 +++++++++++++++++++++++--
 sdks/python/apache_beam/transforms/util_test.py | 26 ++++++++++++++++++++
 4 files changed, 59 insertions(+), 16 deletions(-)

diff --git a/sdks/python/apache_beam/pvalue.py b/sdks/python/apache_beam/pvalue.py
index e17cfab..2b593e4 100644
--- a/sdks/python/apache_beam/pvalue.py
+++ b/sdks/python/apache_beam/pvalue.py
@@ -56,6 +56,7 @@ if TYPE_CHECKING:
 __all__ = [
     'PCollection',
     'TaggedOutput',
+    'AsSideInput',
     'AsSingleton',
     'AsIter',
     'AsList',
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 74c5b74..cb69e64 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -698,19 +698,6 @@ class DoFn(WithTypeHints, HasDisplayData, urns.RunnerApiFn):
   urns.RunnerApiFn.register_pickle_urn(python_urns.PICKLED_DOFN)
 
 
-def _fn_takes_side_inputs(fn):
-  try:
-    signature = get_signature(fn)
-  except TypeError:
-    # We can't tell; maybe it does.
-    return True
-
-  return (
-      len(signature.parameters) > 1 or any(
-          p.kind == p.VAR_POSITIONAL or p.kind == p.VAR_KEYWORD
-          for p in signature.parameters.values()))
-
-
 class CallableWrapperDoFn(DoFn):
   """For internal use only; no backwards-compatibility guarantees.
 
@@ -1564,7 +1551,8 @@ def Map(fn, *args, **kwargs):  # pylint: disable=invalid-name
     raise TypeError(
         'Map can be used only with callable objects. '
         'Received %r instead.' % (fn))
-  if _fn_takes_side_inputs(fn):
+  from apache_beam.transforms.util import fn_takes_side_inputs
+  if fn_takes_side_inputs(fn):
     wrapper = lambda x, *args, **kwargs: [fn(x, *args, **kwargs)]
   else:
     wrapper = lambda x: [fn(x)]
diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py
index d00da29..28024fa 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -40,6 +40,7 @@ from apache_beam import typehints
 from apache_beam.metrics import Metrics
 from apache_beam.portability import common_urns
 from apache_beam.portability.api import beam_runner_api_pb2
+from apache_beam.pvalue import AsSideInput
 from apache_beam.transforms import window
 from apache_beam.transforms.combiners import CountCombineFn
 from apache_beam.transforms.core import CombinePerKey
@@ -63,6 +64,7 @@ from apache_beam.transforms.userstate import on_timer
 from apache_beam.transforms.window import NonMergingWindowFn
 from apache_beam.transforms.window import TimestampCombiner
 from apache_beam.transforms.window import TimestampedValue
+from apache_beam.typehints.decorators import get_signature
 from apache_beam.typehints.sharded_key_type import ShardedKeyType
 from apache_beam.utils import windowed_value
 from apache_beam.utils.annotations import deprecated
@@ -741,14 +743,40 @@ class Reshuffle(PTransform):
     return Reshuffle()
 
 
+def fn_takes_side_inputs(fn):
+  try:
+    signature = get_signature(fn)
+  except TypeError:
+    # We can't tell; maybe it does.
+    return True
+
+  return (
+      len(signature.parameters) > 1 or any(
+          p.kind == p.VAR_POSITIONAL or p.kind == p.VAR_KEYWORD
+          for p in signature.parameters.values()))
+
+
 @ptransform_fn
-def WithKeys(pcoll, k):
+def WithKeys(pcoll, k, *args, **kwargs):
   """PTransform that takes a PCollection, and either a constant key or a
   callable, and returns a PCollection of (K, V), where each of the values in
   the input PCollection has been paired with either the constant key or a key
-  computed from the value.
+  computed from the value.  The callable may optionally accept positional or
+  keyword arguments, which should be passed to WithKeys directly.  These may
+  be either SideInputs or static (non-PCollection) values, such as ints.
   """
   if callable(k):
+    if fn_takes_side_inputs(k):
+      if all([isinstance(arg, AsSideInput)
+              for arg in args]) and all([isinstance(kwarg, AsSideInput)
+                                         for kwarg in kwargs.values()]):
+        return pcoll | Map(
+            lambda v,
+            *args,
+            **kwargs: (k(v, *args, **kwargs), v),
+            *args,
+            **kwargs)
+      return pcoll | Map(lambda v: (k(v, *args, **kwargs), v))
     return pcoll | Map(lambda v: (k(v), v))
   return pcoll | Map(lambda v: (k, v))
 
diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py
index a283e19..94cc8f3 100644
--- a/sdks/python/apache_beam/transforms/util_test.py
+++ b/sdks/python/apache_beam/transforms/util_test.py
@@ -38,6 +38,8 @@ from apache_beam.options.pipeline_options import PipelineOptions
 from apache_beam.options.pipeline_options import StandardOptions
 from apache_beam.portability import common_urns
 from apache_beam.portability.api import beam_runner_api_pb2
+from apache_beam.pvalue import AsList
+from apache_beam.pvalue import AsSingleton
 from apache_beam.runners import pipeline_context
 from apache_beam.testing.test_pipeline import TestPipeline
 from apache_beam.testing.test_stream import TestStream
@@ -628,6 +630,30 @@ class WithKeysTest(unittest.TestCase):
       with_keys = pc | util.WithKeys(lambda x: x * x)
     assert_that(with_keys, equal_to([(1, 1), (4, 2), (9, 3)]))
 
+  @staticmethod
+  def _test_args_kwargs_fn(x, multiply, subtract):
+    return x * multiply - subtract
+
+  def test_args_kwargs_k(self):
+    with TestPipeline() as p:
+      pc = p | beam.Create(self.l)
+      with_keys = pc | util.WithKeys(
+          WithKeysTest._test_args_kwargs_fn, 2, subtract=1)
+    assert_that(with_keys, equal_to([(1, 1), (3, 2), (5, 3)]))
+
+  def test_sideinputs(self):
+    with TestPipeline() as p:
+      pc = p | beam.Create(self.l)
+      si1 = AsList(p | "side input 1" >> beam.Create([1, 2, 3]))
+      si2 = AsSingleton(p | "side input 2" >> beam.Create([10]))
+      with_keys = pc | util.WithKeys(
+          lambda x,
+          the_list,
+          the_singleton: x + sum(the_list) + the_singleton,
+          si1,
+          the_singleton=si2)
+    assert_that(with_keys, equal_to([(17, 1), (18, 2), (19, 3)]))
+
 
 class GroupIntoBatchesTest(unittest.TestCase):
   NUM_ELEMENTS = 10

Mime
View raw message