Return-Path: X-Original-To: archive-asf-public-internal@cust-asf2.ponee.io Delivered-To: archive-asf-public-internal@cust-asf2.ponee.io Received: from cust-asf.ponee.io (cust-asf.ponee.io [163.172.22.183]) by cust-asf2.ponee.io (Postfix) with ESMTP id 49D05200C12 for ; Sun, 22 Jan 2017 05:37:32 +0100 (CET) Received: by cust-asf.ponee.io (Postfix) id 487F7160B56; Sun, 22 Jan 2017 04:37:32 +0000 (UTC) Delivered-To: archive-asf-public@cust-asf.ponee.io Received: from mail.apache.org (hermes.apache.org [140.211.11.3]) by cust-asf.ponee.io (Postfix) with SMTP id C763C160B4A for ; Sun, 22 Jan 2017 05:37:30 +0100 (CET) Received: (qmail 10902 invoked by uid 500); 22 Jan 2017 04:37:29 -0000 Mailing-List: contact commits-help@beam.apache.org; run by ezmlm Precedence: bulk List-Help: List-Unsubscribe: List-Post: List-Id: Reply-To: dev@beam.apache.org Delivered-To: mailing list commits@beam.apache.org Received: (qmail 10887 invoked by uid 99); 22 Jan 2017 04:37:29 -0000 Received: from git1-us-west.apache.org (HELO git1-us-west.apache.org) (140.211.11.23) by apache.org (qpsmtpd/0.29) with ESMTP; Sun, 22 Jan 2017 04:37:29 +0000 Received: by git1-us-west.apache.org (ASF Mail Server at git1-us-west.apache.org, from userid 33) id E9568DFEF3; Sun, 22 Jan 2017 04:37:28 +0000 (UTC) Content-Type: text/plain; charset="us-ascii" MIME-Version: 1.0 Content-Transfer-Encoding: 7bit From: robertwb@apache.org To: commits@beam.apache.org Date: Sun, 22 Jan 2017 04:37:28 -0000 Message-Id: X-Mailer: ASF-Git Admin Mailer Subject: [1/2] beam git commit: Implement Annotation based NewDoFn in python SDK archived-at: Sun, 22 Jan 2017 04:37:32 -0000 Repository: beam Updated Branches: refs/heads/python-sdk 946135f6a -> d0474ab5b Implement Annotation based NewDoFn in python SDK Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/9e272ecf Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/9e272ecf Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/9e272ecf Branch: refs/heads/python-sdk Commit: 9e272ecf639b7b13f23a83868fd101a437159c1c Parents: 946135f Author: Sourabh Bajaj Authored: Fri Jan 20 17:17:25 2017 -0800 Committer: Robert Bradshaw Committed: Sat Jan 21 20:37:07 2017 -0800 ---------------------------------------------------------------------- sdks/python/apache_beam/pipeline_test.py | 100 ++++++++- sdks/python/apache_beam/runners/common.pxd | 4 + sdks/python/apache_beam/runners/common.py | 221 +++++++++++++------ .../runners/direct/transform_evaluator.py | 15 +- sdks/python/apache_beam/transforms/core.py | 113 +++++++++- sdks/python/apache_beam/typehints/decorators.py | 2 +- sdks/python/apache_beam/typehints/typecheck.py | 145 ++++++++++++ 7 files changed, 531 insertions(+), 69 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/9e272ecf/sdks/python/apache_beam/pipeline_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/pipeline_test.py b/sdks/python/apache_beam/pipeline_test.py index 336bf54..93b68d1 100644 --- a/sdks/python/apache_beam/pipeline_test.py +++ b/sdks/python/apache_beam/pipeline_test.py @@ -24,15 +24,23 @@ import unittest from apache_beam.pipeline import Pipeline from apache_beam.pipeline import PipelineOptions from apache_beam.pipeline import PipelineVisitor +from apache_beam.pvalue import AsSingleton from apache_beam.runners.dataflow.native_io.iobase import NativeSource from apache_beam.test_pipeline import TestPipeline from apache_beam.transforms import CombineGlobally from apache_beam.transforms import Create from apache_beam.transforms import FlatMap from apache_beam.transforms import Map +from apache_beam.transforms import NewDoFn +from apache_beam.transforms import ParDo from apache_beam.transforms import PTransform from apache_beam.transforms import Read -from apache_beam.transforms.util import assert_that, equal_to +from apache_beam.transforms import WindowInto +from apache_beam.transforms.util import assert_that +from apache_beam.transforms.util import equal_to +from apache_beam.transforms.window import IntervalWindow +from apache_beam.transforms.window import WindowFn +from apache_beam.utils.timestamp import MIN_TIMESTAMP class FakeSource(NativeSource): @@ -241,6 +249,96 @@ class PipelineTest(unittest.TestCase): self.assertEqual([1, 4, 9], p | Create([1, 2, 3]) | Map(lambda x: x*x)) +class NewDoFnTest(unittest.TestCase): + + def setUp(self): + self.runner_name = 'DirectRunner' + + def test_element(self): + class TestDoFn(NewDoFn): + def process(self, element): + yield element + 10 + + pipeline = TestPipeline(runner=self.runner_name) + pcoll = pipeline | 'Create' >> Create([1, 2]) | 'Do' >> ParDo(TestDoFn()) + assert_that(pcoll, equal_to([11, 12])) + pipeline.run() + + def test_context_param(self): + class TestDoFn(NewDoFn): + def process(self, element, context=NewDoFn.ContextParam): + yield context.element + 10 + + pipeline = TestPipeline(runner=self.runner_name) + pcoll = pipeline | 'Create' >> Create([1, 2])| 'Do' >> ParDo(TestDoFn()) + assert_that(pcoll, equal_to([11, 12])) + pipeline.run() + + def test_side_input_no_tag(self): + class TestDoFn(NewDoFn): + def process(self, element, prefix, suffix): + return ['%s-%s-%s' % (prefix, element, suffix)] + + pipeline = TestPipeline() + words_list = ['aa', 'bb', 'cc'] + words = pipeline | 'SomeWords' >> Create(words_list) + prefix = 'zyx' + suffix = pipeline | 'SomeString' >> Create(['xyz']) # side in + result = words | 'DecorateWordsDoFnNoTag' >> ParDo( + TestDoFn(), prefix, suffix=AsSingleton(suffix)) + assert_that(result, equal_to(['zyx-%s-xyz' % x for x in words_list])) + pipeline.run() + + def test_side_input_tagged(self): + class TestDoFn(NewDoFn): + def process(self, element, prefix, suffix=NewDoFn.SideInputParam): + return ['%s-%s-%s' % (prefix, element, suffix)] + + pipeline = TestPipeline() + words_list = ['aa', 'bb', 'cc'] + words = pipeline | 'SomeWords' >> Create(words_list) + prefix = 'zyx' + suffix = pipeline | 'SomeString' >> Create(['xyz']) # side in + result = words | 'DecorateWordsDoFnNoTag' >> ParDo( + TestDoFn(), prefix, suffix=AsSingleton(suffix)) + assert_that(result, equal_to(['zyx-%s-xyz' % x for x in words_list])) + pipeline.run() + + def test_window_param(self): + class TestDoFn(NewDoFn): + def process(self, element, window=NewDoFn.WindowParam): + yield (float(window.start), float(window.end)) + + class TestWindowFn(WindowFn): + """Windowing function adding two disjoint windows to each element.""" + + def assign(self, assign_context): + _ = assign_context + return [IntervalWindow(10, 20), IntervalWindow(20, 30)] + + def merge(self, existing_windows): + return existing_windows + + pipeline = TestPipeline(runner=self.runner_name) + pcoll = (pipeline + | 'KVs' >> Create([(1, 10), (2, 20)]) + | 'W' >> WindowInto(windowfn=TestWindowFn()) + | 'Do' >> ParDo(TestDoFn())) + assert_that(pcoll, equal_to([(10.0, 20.0), (10.0, 20.0), + (20.0, 30.0), (20.0, 30.0)])) + pipeline.run() + + def test_timestamp_param(self): + class TestDoFn(NewDoFn): + def process(self, element, timestamp=NewDoFn.TimestampParam): + yield timestamp + + pipeline = TestPipeline(runner=self.runner_name) + pcoll = pipeline | 'Create' >> Create([1, 2]) | 'Do' >> ParDo(TestDoFn()) + assert_that(pcoll, equal_to([MIN_TIMESTAMP, MIN_TIMESTAMP])) + pipeline.run() + + class Bacon(PipelineOptions): @classmethod http://git-wip-us.apache.org/repos/asf/beam/blob/9e272ecf/sdks/python/apache_beam/runners/common.pxd ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/common.pxd b/sdks/python/apache_beam/runners/common.pxd index 085fd11..06fe434 100644 --- a/sdks/python/apache_beam/runners/common.pxd +++ b/sdks/python/apache_beam/runners/common.pxd @@ -36,6 +36,10 @@ cdef class DoFnRunner(Receiver): cdef object tagged_receivers cdef LoggingContext logging_context cdef object step_name + cdef object is_new_dofn + cdef object args + cdef object kwargs + cdef object side_inputs cdef bint has_windowed_side_inputs cdef Receiver main_receivers http://git-wip-us.apache.org/repos/asf/beam/blob/9e272ecf/sdks/python/apache_beam/runners/common.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py index cc834ba..0f63cbc 100644 --- a/sdks/python/apache_beam/runners/common.py +++ b/sdks/python/apache_beam/runners/common.py @@ -71,50 +71,21 @@ class DoFnRunner(Receiver): # Preferred alternative to context # TODO(robertwb): Remove once all runners are updated. state=None): - self.has_windowed_side_inputs = False # Set to True in one case below. - if not args and not kwargs: - self.dofn = fn - self.dofn_process = fn.process - else: - global_window = window.GlobalWindow() - # TODO(robertwb): Remove when all runners pass side input maps. - side_inputs = [side_input - if isinstance(side_input, sideinputs.SideInputMap) - else {global_window: side_input} - for side_input in side_inputs] - if side_inputs and all( - isinstance(side_input, dict) or side_input.is_globally_windowed() - for side_input in side_inputs): - args, kwargs = util.insert_values_in_args( - args, kwargs, [side_input[global_window] - for side_input in side_inputs]) - side_inputs = [] - if side_inputs: - self.has_windowed_side_inputs = True - - def process(context): - w = context.windows[0] - cur_args, cur_kwargs = util.insert_values_in_args( - args, kwargs, [side_input[w] for side_input in side_inputs]) - return fn.process(context, *cur_args, **cur_kwargs) - self.dofn_process = process - elif kwargs: - self.dofn_process = lambda context: fn.process(context, *args, **kwargs) - else: - self.dofn_process = lambda context: fn.process(context, *args) - - class CurriedFn(core.DoFn): + self.step_name = step_name + self.window_fn = windowing.windowfn + self.tagged_receivers = tagged_receivers - start_bundle = staticmethod(fn.start_bundle) - process = staticmethod(self.dofn_process) - finish_bundle = staticmethod(fn.finish_bundle) + global_window = window.GlobalWindow() - self.dofn = CurriedFn() + if logging_context: + self.logging_context = logging_context + else: + self.logging_context = get_logging_context(logger, step_name=step_name) - self.window_fn = windowing.windowfn - self.tagged_receivers = tagged_receivers - self.step_name = step_name + # Optimize for the common case. + self.main_receivers = as_receiver(tagged_receivers[None]) + # TODO(sourabh): Deprecate the use of context if state: assert context is None self.context = DoFnContext(self.step_name, state=state) @@ -122,48 +93,172 @@ class DoFnRunner(Receiver): assert context is not None self.context = context - if logging_context: - self.logging_context = logging_context + # TODO(Sourabhbajaj): Remove the usage of OldDoFn + if isinstance(fn, core.NewDoFn): + self.is_new_dofn = True + + # SideInputs + self.side_inputs = [side_input + if isinstance(side_input, sideinputs.SideInputMap) + else {global_window: side_input} + for side_input in side_inputs] + self.has_windowed_side_inputs = not all( + isinstance(si, dict) or si.is_globally_windowed() + for si in self.side_inputs) + + self.args = args if args else [] + self.kwargs = kwargs if kwargs else {} + self.dofn = fn + else: - self.logging_context = get_logging_context(logger, step_name=step_name) + self.is_new_dofn = False + self.has_windowed_side_inputs = False # Set to True in one case below. + if not args and not kwargs: + self.dofn = fn + self.dofn_process = fn.process + else: + # TODO(robertwb): Remove when all runners pass side input maps. + side_inputs = [side_input + if isinstance(side_input, sideinputs.SideInputMap) + else {global_window: side_input} + for side_input in side_inputs] + if side_inputs and all( + isinstance(side_input, dict) or side_input.is_globally_windowed() + for side_input in side_inputs): + args, kwargs = util.insert_values_in_args( + args, kwargs, [side_input[global_window] + for side_input in side_inputs]) + side_inputs = [] + if side_inputs: + self.has_windowed_side_inputs = True + + def process(context): + w = context.windows[0] + cur_args, cur_kwargs = util.insert_values_in_args( + args, kwargs, [side_input[w] for side_input in side_inputs]) + return fn.process(context, *cur_args, **cur_kwargs) + self.dofn_process = process + elif kwargs: + self.dofn_process = lambda context: fn.process( + context, *args, **kwargs) + else: + self.dofn_process = lambda context: fn.process(context, *args) - # Optimize for the common case. - self.main_receivers = as_receiver(tagged_receivers[None]) + class CurriedFn(core.DoFn): + + start_bundle = staticmethod(fn.start_bundle) + process = staticmethod(self.dofn_process) + finish_bundle = staticmethod(fn.finish_bundle) + + self.dofn = CurriedFn() def receive(self, windowed_value): self.process(windowed_value) - def start(self): - self.context.set_element(None) + def old_dofn_process(self, element): + if self.has_windowed_side_inputs and len(element.windows) > 1: + for w in element.windows: + self.context.set_element( + WindowedValue(element.value, element.timestamp, (w,))) + self._process_outputs(element, self.dofn_process(self.context)) + else: + self.context.set_element(element) + self._process_outputs(element, self.dofn_process(self.context)) + + def new_dofn_process(self, element): + self.context.set_element(element) + arguments, _, _, defaults = self.dofn.get_function_arguments('process') + defaults = defaults if defaults else [] + + self_in_args = int(self.dofn.is_process_bounded()) + + # Call for the process function for each window if has windowed side inputs + # or if the process accesses the window parameter. We can just call it once + # otherwise as none of the arguments are changing + if self.has_windowed_side_inputs or core.NewDoFn.WindowParam in defaults: + windows = element.windows + else: + windows = [window.GlobalWindow()] + + for w in windows: + args, kwargs = util.insert_values_in_args( + self.args, self.kwargs, + [s[w] for s in self.side_inputs]) + + # If there are more arguments than the default then the first argument + # should be the element and the rest should be picked from the side + # inputs as window and timestamp should always be tagged + if len(arguments) > len(defaults) + self_in_args: + if core.NewDoFn.ElementParam not in defaults: + args_to_pick = len(arguments) - len(defaults) - 1 - self_in_args + final_args = [element.value] + args[:args_to_pick] + else: + args_to_pick = len(arguments) - len(defaults) - self_in_args + final_args = args[:args_to_pick] + else: + args_to_pick = 0 + final_args = [] + args = iter(args[args_to_pick:]) + + for a, d in zip(arguments[-len(defaults):], defaults): + if d == core.NewDoFn.ElementParam: + final_args.append(element.value) + elif d == core.NewDoFn.ContextParam: + final_args.append(self.context) + elif d == core.NewDoFn.WindowParam: + final_args.append(w) + elif d == core.NewDoFn.TimestampParam: + final_args.append(element.timestamp) + elif d == core.NewDoFn.SideInputParam: + # If no more args are present then the value must be passed via kwarg + try: + final_args.append(args.next()) + except StopIteration: + if a not in kwargs: + raise + else: + # If no more args are present then the value must be passed via kwarg + try: + final_args.append(args.next()) + except StopIteration: + if a not in kwargs: + kwargs[a] = d + final_args.extend(list(args)) + self._process_outputs(element, self.dofn.process(*final_args, **kwargs)) + + def _invoke_bundle_method(self, method): try: self.logging_context.enter() - self._process_outputs(None, self.dofn.start_bundle(self.context)) + self.context.set_element(None) + f = getattr(self.dofn, method) + + # TODO(Sourabhbajaj): Remove this if-else + if self.is_new_dofn: + _, _, _, defaults = self.dofn.get_function_arguments(method) + defaults = defaults if defaults else [] + args = [self.context if d == core.NewDoFn.ContextParam else d + for d in defaults] + self._process_outputs(None, f(*args)) + else: + self._process_outputs(None, f(self.context)) except BaseException as exn: self.reraise_augmented(exn) finally: self.logging_context.exit() + def start(self): + self._invoke_bundle_method('start_bundle') + def finish(self): - self.context.set_element(None) - try: - self.logging_context.enter() - self._process_outputs(None, self.dofn.finish_bundle(self.context)) - except BaseException as exn: - self.reraise_augmented(exn) - finally: - self.logging_context.exit() + self._invoke_bundle_method('finish_bundle') def process(self, element): try: self.logging_context.enter() - if self.has_windowed_side_inputs and len(element.windows) > 1: - for w in element.windows: - self.context.set_element( - WindowedValue(element.value, element.timestamp, (w,))) - self._process_outputs(element, self.dofn_process(self.context)) + if self.is_new_dofn: + self.new_dofn_process(element) else: - self.context.set_element(element) - self._process_outputs(element, self.dofn_process(self.context)) + self.old_dofn_process(element) except BaseException as exn: self.reraise_augmented(exn) finally: http://git-wip-us.apache.org/repos/asf/beam/blob/9e272ecf/sdks/python/apache_beam/runners/direct/transform_evaluator.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/direct/transform_evaluator.py b/sdks/python/apache_beam/runners/direct/transform_evaluator.py index b4c43ba..ec2b3a1 100644 --- a/sdks/python/apache_beam/runners/direct/transform_evaluator.py +++ b/sdks/python/apache_beam/runners/direct/transform_evaluator.py @@ -35,8 +35,10 @@ from apache_beam.transforms import sideinputs from apache_beam.transforms.window import GlobalWindows from apache_beam.transforms.window import WindowedValue from apache_beam.typehints.typecheck import OutputCheckWrapperDoFn +from apache_beam.typehints.typecheck import OutputCheckWrapperNewDoFn from apache_beam.typehints.typecheck import TypeCheckError from apache_beam.typehints.typecheck import TypeCheckWrapperDoFn +from apache_beam.typehints.typecheck import TypeCheckWrapperNewDoFn from apache_beam.utils import counters from apache_beam.utils.pipeline_options import TypeOptions @@ -344,9 +346,18 @@ class _ParDoEvaluator(_TransformEvaluator): pipeline_options = self._evaluation_context.pipeline_options if (pipeline_options is not None and pipeline_options.view_as(TypeOptions).runtime_type_check): - dofn = TypeCheckWrapperDoFn(dofn, transform.get_type_hints()) + # TODO(sourabhbajaj): Remove this if-else + if isinstance(dofn, core.NewDoFn): + dofn = TypeCheckWrapperNewDoFn(dofn, transform.get_type_hints()) + else: + dofn = TypeCheckWrapperDoFn(dofn, transform.get_type_hints()) - dofn = OutputCheckWrapperDoFn(dofn, self._applied_ptransform.full_label) + # TODO(sourabhbajaj): Remove this if-else + if isinstance(dofn, core.NewDoFn): + dofn = OutputCheckWrapperNewDoFn( + dofn, self._applied_ptransform.full_label) + else: + dofn = OutputCheckWrapperDoFn(dofn, self._applied_ptransform.full_label) self.runner = DoFnRunner(dofn, transform.args, transform.kwargs, self._side_inputs, self._applied_ptransform.inputs[0].windowing, http://git-wip-us.apache.org/repos/asf/beam/blob/9e272ecf/sdks/python/apache_beam/transforms/core.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index 72f7cd4..70a03ae 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -109,6 +109,7 @@ class DoFnProcessContext(DoFnContext): self.timestamp = windowed_value.timestamp self.windows = windowed_value.windows + # TODO(sourabhbajaj): Move as we're trying to deprecate the use of context def aggregate_to(self, aggregator, input_value): """Provide a new input value for the aggregator. @@ -119,6 +120,112 @@ class DoFnProcessContext(DoFnContext): self.state.counter_for(aggregator).update(input_value) +class NewDoFn(WithTypeHints, HasDisplayData): + """A function object used by a transform with custom processing. + + The ParDo transform is such a transform. The ParDo.apply + method will take an object of type DoFn and apply it to all elements of a + PCollection object. + + In order to have concrete DoFn objects one has to subclass from DoFn and + define the desired behavior (start_bundle/finish_bundle and process) or wrap a + callable object using the CallableWrapperDoFn class. + """ + + ElementParam = 'ElementParam' + ContextParam = 'ContextParam' + SideInputParam = 'SideInputParam' + TimestampParam = 'TimestampParam' + WindowParam = 'WindowParam' + + @staticmethod + def from_callable(fn): + return CallableWrapperDoFn(fn) + + def default_label(self): + return self.__class__.__name__ + + def process(self, element, *args, **kwargs): + """Called for each element of a pipeline. The default arguments are needed + for the DoFnRunner to be able to pass the parameters correctly. + + Args: + element: The element to be processed + context: a DoFnProcessContext object containing. See the + DoFnProcessContext documentation for details. + *args: side inputs + **kwargs: keyword side inputs + """ + raise NotImplementedError + + def start_bundle(self): + """Called before a bundle of elements is processed on a worker. + + Elements to be processed are split into bundles and distributed + to workers. Before a worker calls process() on the first element + of its bundle, it calls this method. + """ + pass + + def finish_bundle(self): + """Called after a bundle of elements is processed on a worker. + """ + pass + + def get_function_arguments(self, func): + """Return the function arguments based on the name provided. If they have + a _inspect_function attached to the class then use that otherwise default + to the python inspect library. + """ + func_name = '_inspect_%s' % func + if hasattr(self, func_name): + f = getattr(self, func_name) + return f() + else: + f = getattr(self, func) + return inspect.getargspec(f) + + # TODO(sourabhbajaj): Do we want to remove the responsiblity of these from + # the DoFn or maybe the runner + def infer_output_type(self, input_type): + # TODO(robertwb): Side inputs types. + # TODO(robertwb): Assert compatibility with input type hint? + return self._strip_output_annotations( + trivial_inference.infer_return_type(self.process, [input_type])) + + def _strip_output_annotations(self, type_hint): + annotations = (window.TimestampedValue, window.WindowedValue, + pvalue.SideOutputValue) + # TODO(robertwb): These should be parameterized types that the + # type inferencer understands. + if (type_hint in annotations + or trivial_inference.element_type(type_hint) in annotations): + return Any + else: + return type_hint + + def process_argspec_fn(self): + """Returns the Python callable that will eventually be invoked. + + This should ideally be the user-level function that is called with + the main and (if any) side inputs, and is used to relate the type + hint parameters with the input parameters (e.g., by argument name). + """ + return self.process + + def is_process_bounded(self): + """Checks if an object is a bound method on an instance.""" + if not isinstance(self.process, types.MethodType): + return False # Not a method + if self.process.im_self is None: + return False # Method is not bound + if issubclass(self.process.im_class, type) or \ + self.process.im_class is types.ClassType: + return False # Method is a classmethod + return True + + +# TODO(Sourabh): Remove after migration to NewDoFn class DoFn(WithTypeHints, HasDisplayData): """A function object used by a transform with custom processing. @@ -577,7 +684,7 @@ class ParDo(PTransformWithSideInputs): def __init__(self, fn_or_label, *args, **kwargs): super(ParDo, self).__init__(fn_or_label, *args, **kwargs) - if not isinstance(self.fn, DoFn): + if not isinstance(self.fn, (DoFn, NewDoFn)): raise TypeError('ParDo must be called with a DoFn instance.') def default_type_hints(self): @@ -588,7 +695,9 @@ class ParDo(PTransformWithSideInputs): self.fn.infer_output_type(input_type)) def make_fn(self, fn): - return fn if isinstance(fn, DoFn) else CallableWrapperDoFn(fn) + if isinstance(fn, (DoFn, NewDoFn)): + return fn + return CallableWrapperDoFn(fn) def process_argspec_fn(self): return self.fn.process_argspec_fn() http://git-wip-us.apache.org/repos/asf/beam/blob/9e272ecf/sdks/python/apache_beam/typehints/decorators.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/typehints/decorators.py b/sdks/python/apache_beam/typehints/decorators.py index a300a3b..df15f1b 100644 --- a/sdks/python/apache_beam/typehints/decorators.py +++ b/sdks/python/apache_beam/typehints/decorators.py @@ -263,7 +263,7 @@ def getcallargs_forhints(func, *typeargs, **typekwargs): for k, var in enumerate(reversed(argspec.args)): if k >= len(argspec.defaults): break - if callargs.get(var, None) is argspec.defaults[-k]: + if callargs.get(var, None) is argspec.defaults[-k-1]: callargs[var] = typehints.Any # Patch up varargs and keywords if argspec.varargs: http://git-wip-us.apache.org/repos/asf/beam/blob/9e272ecf/sdks/python/apache_beam/typehints/typecheck.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/typehints/typecheck.py b/sdks/python/apache_beam/typehints/typecheck.py index d759d55..7a10a5a 100644 --- a/sdks/python/apache_beam/typehints/typecheck.py +++ b/sdks/python/apache_beam/typehints/typecheck.py @@ -24,6 +24,7 @@ import types from apache_beam.pvalue import SideOutputValue from apache_beam.transforms.core import DoFn +from apache_beam.transforms.core import NewDoFn from apache_beam.transforms.window import WindowedValue from apache_beam.typehints import check_constraint from apache_beam.typehints import CompositeTypeHintError @@ -162,3 +163,147 @@ class OutputCheckWrapperDoFn(DoFn): 'iterable. %s was returned instead.' % type(output)) return output + + +class AbstractDoFnWrapper(NewDoFn): + """An abstract class to create wrapper around NewDoFn""" + + def __init__(self, dofn): + super(AbstractDoFnWrapper, self).__init__() + self.dofn = dofn + + def _inspect_start_bundle(self): + return self.dofn.get_function_arguments('start_bundle') + + def _inspect_process(self): + return self.dofn.get_function_arguments('process') + + def _inspect_finish_bundle(self): + return self.dofn.get_function_arguments('finish_bundle') + + def wrapper(self, method, args, kwargs): + return method(*args, **kwargs) + + def start_bundle(self, *args, **kwargs): + return self.wrapper(self.dofn.start_bundle, args, kwargs) + + def process(self, *args, **kwargs): + return self.wrapper(self.dofn.process, args, kwargs) + + def finish_bundle(self, *args, **kwargs): + return self.wrapper(self.dofn.finish_bundle, args, kwargs) + + def is_process_bounded(self): + return self.dofn.is_process_bounded() + + +class OutputCheckWrapperNewDoFn(AbstractDoFnWrapper): + """A DoFn that verifies against common errors in the output type.""" + + def __init__(self, dofn, full_label): + super(OutputCheckWrapperNewDoFn, self).__init__(dofn) + self.full_label = full_label + + def wrapper(self, method, args, kwargs): + try: + result = method(*args, **kwargs) + except TypeCheckError as e: + error_msg = ('Runtime type violation detected within ParDo(%s): ' + '%s' % (self.full_label, e)) + raise TypeCheckError, error_msg, sys.exc_info()[2] + else: + return self._check_type(result) + + def _check_type(self, output): + if output is None: + return output + elif isinstance(output, (dict, basestring)): + object_type = type(output).__name__ + raise TypeCheckError('Returning a %s from a ParDo or FlatMap is ' + 'discouraged. Please use list("%s") if you really ' + 'want this behavior.' % + (object_type, output)) + elif not isinstance(output, collections.Iterable): + raise TypeCheckError('FlatMap and ParDo must return an ' + 'iterable. %s was returned instead.' + % type(output)) + return output + + +class TypeCheckWrapperNewDoFn(AbstractDoFnWrapper): + """A wrapper around a DoFn which performs type-checking of input and output. + """ + + def __init__(self, dofn, type_hints, label=None): + super(TypeCheckWrapperNewDoFn, self).__init__(dofn) + self.dofn = dofn + self._process_fn = self.dofn.process_argspec_fn() + if type_hints.input_types: + input_args, input_kwargs = type_hints.input_types + self._input_hints = getcallargs_forhints( + self._process_fn, *input_args, **input_kwargs) + else: + self._input_hints = None + # TODO(robertwb): Multi-output. + self._output_type_hint = type_hints.simple_output_type(label) + + def wrapper(self, method, args, kwargs): + result = method(*args, **kwargs) + return self._type_check_result(result) + + def process(self, *args, **kwargs): + if self._input_hints: + actual_inputs = inspect.getcallargs(self._process_fn, *args, **kwargs) + for var, hint in self._input_hints.items(): + if hint is actual_inputs[var]: + # self parameter + continue + _check_instance_type(hint, actual_inputs[var], var, True) + return self._type_check_result(self.dofn.process(*args, **kwargs)) + + def _type_check_result(self, transform_results): + if self._output_type_hint is None or transform_results is None: + return transform_results + + def type_check_output(o): + # TODO(robertwb): Multi-output. + x = o.value if isinstance(o, (SideOutputValue, WindowedValue)) else o + self._type_check(self._output_type_hint, x, is_input=False) + + # If the return type is a generator, then we will need to interleave our + # type-checking with its normal iteration so we don't deplete the + # generator initially just by type-checking its yielded contents. + if isinstance(transform_results, types.GeneratorType): + return GeneratorWrapper(transform_results, type_check_output) + else: + for o in transform_results: + type_check_output(o) + return transform_results + + def _type_check(self, type_constraint, datum, is_input): + """Typecheck a PTransform related datum according to a type constraint. + + This function is used to optionally type-check either an input or an output + to a PTransform. + + Args: + type_constraint: An instance of a typehints.TypeContraint, one of the + white-listed builtin Python types, or a custom user class. + datum: An instance of a Python object. + is_input: True if 'datum' is an input to a PTransform's DoFn. False + otherwise. + + Raises: + TypeError: If 'datum' fails to type-check according to 'type_constraint'. + """ + datum_type = 'input' if is_input else 'output' + + try: + check_constraint(type_constraint, datum) + except CompositeTypeHintError as e: + raise TypeCheckError, e.message, sys.exc_info()[2] + except SimpleTypeHintError: + error_msg = ("According to type-hint expected %s should be of type %s. " + "Instead, received '%s', an instance of type %s." + % (datum_type, type_constraint, datum, type(datum))) + raise TypeCheckError, error_msg, sys.exc_info()[2]