beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From rober...@apache.org
Subject [1/2] beam git commit: [BEAM-1925] validate DoFn at pipeline creation time
Date Tue, 09 May 2017 06:21:47 GMT
Repository: beam
Updated Branches:
  refs/heads/master 844762d10 -> d96fd173c


[BEAM-1925] validate DoFn at pipeline creation time


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/9283a5e8
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/9283a5e8
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/9283a5e8

Branch: refs/heads/master
Commit: 9283a5e8f931f554b507739f4448c862caa7e5cd
Parents: 844762d
Author: Sourabh Bajaj <sourabhbajaj@google.com>
Authored: Mon May 8 13:33:15 2017 -0700
Committer: Robert Bradshaw <robertwb@gmail.com>
Committed: Mon May 8 23:21:34 2017 -0700

----------------------------------------------------------------------
 sdks/python/apache_beam/runners/common.py       | 25 ++++-----
 sdks/python/apache_beam/runners/common_test.py  | 58 ++++++++++++++++++++
 sdks/python/apache_beam/transforms/core.py      |  9 ++-
 .../apache_beam/transforms/ptransform_test.py   |  1 -
 4 files changed, 78 insertions(+), 15 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/9283a5e8/sdks/python/apache_beam/runners/common.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py
index 045c109..74c61ab 100644
--- a/sdks/python/apache_beam/runners/common.py
+++ b/sdks/python/apache_beam/runners/common.py
@@ -95,17 +95,20 @@ class DoFnSignature(object):
     self._validate()
 
   def _validate(self):
+    self._validate_process()
     self._validate_bundle_method(self.start_bundle_method)
     self._validate_bundle_method(self.finish_bundle_method)
 
-  def _validate_bundle_method(self, method_wrapper):
-    # Here we use the fact that every DoFn parameter defined in core.DoFn has
-    # the value that is the same as the name of the parameter and ends with
-    # string 'Param'.
-    unsupported_dofn_params = [i for i in core.DoFn.__dict__ if
-                               i.endswith('Param')]
+  def _validate_process(self):
+    """Validate that none of the DoFnParameters are repeated in the function
+    """
+    for param in core.DoFn.DoFnParams:
+      assert self.process_method.defaults.count(param) <= 1
 
-    for param in unsupported_dofn_params:
+  def _validate_bundle_method(self, method_wrapper):
+    """Validate that none of the DoFnParameters are used in the function
+    """
+    for param in core.DoFn.DoFnParams:
       assert param not in method_wrapper.defaults
 
 
@@ -156,18 +159,14 @@ class DoFnInvoker(object):
   def invoke_start_bundle(self):
     """Invokes the DoFn.start_bundle() method.
     """
-    args_for_start_bundle = self.signature.start_bundle_method.defaults
     self.output_processor.start_bundle_outputs(
-        self.signature.start_bundle_method.method_value(
-            *args_for_start_bundle))
+        self.signature.start_bundle_method.method_value())
 
   def invoke_finish_bundle(self):
     """Invokes the DoFn.finish_bundle() method.
     """
-    args_for_finish_bundle = self.signature.finish_bundle_method.defaults
     self.output_processor.finish_bundle_outputs(
-        self.signature.finish_bundle_method.method_value(
-            *args_for_finish_bundle))
+        self.signature.finish_bundle_method.method_value())
 
 
 class SimpleInvoker(DoFnInvoker):

http://git-wip-us.apache.org/repos/asf/beam/blob/9283a5e8/sdks/python/apache_beam/runners/common_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/common_test.py b/sdks/python/apache_beam/runners/common_test.py
new file mode 100644
index 0000000..62a6955
--- /dev/null
+++ b/sdks/python/apache_beam/runners/common_test.py
@@ -0,0 +1,58 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from apache_beam.transforms.core import DoFn
+from apache_beam.runners.common import DoFnSignature
+
+
+class DoFnSignatureTest(unittest.TestCase):
+
+  def test_dofn_validate_process_error(self):
+    class MyDoFn(DoFn):
+      def process(self, element, w1=DoFn.WindowParam, w2=DoFn.WindowParam):
+        pass
+
+    with self.assertRaises(AssertionError):
+      DoFnSignature(MyDoFn())
+
+  def test_dofn_validate_start_bundle_error(self):
+    class MyDoFn(DoFn):
+      def process(self, element):
+        pass
+
+      def start_bundle(self, w1=DoFn.WindowParam):
+        pass
+
+    with self.assertRaises(AssertionError):
+      DoFnSignature(MyDoFn())
+
+  def test_dofn_validate_finish_bundle_error(self):
+    class MyDoFn(DoFn):
+      def process(self, element):
+        pass
+
+      def finish_bundle(self, w1=DoFn.WindowParam):
+        pass
+
+    with self.assertRaises(AssertionError):
+      DoFnSignature(MyDoFn())
+
+
+if __name__ == '__main__':
+  unittest.main()

http://git-wip-us.apache.org/repos/asf/beam/blob/9283a5e8/sdks/python/apache_beam/transforms/core.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 7ca1632..e37a387 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -29,7 +29,8 @@ from apache_beam.coders import typecoders
 from apache_beam.internal import util
 from apache_beam.runners.api import beam_runner_api_pb2
 from apache_beam.transforms import ptransform
-from apache_beam.transforms.display import HasDisplayData, DisplayDataItem
+from apache_beam.transforms.display import DisplayDataItem
+from apache_beam.transforms.display import HasDisplayData
 from apache_beam.transforms.ptransform import PTransform
 from apache_beam.transforms.ptransform import PTransformWithSideInputs
 from apache_beam.transforms.window import MIN_TIMESTAMP
@@ -131,6 +132,8 @@ class DoFn(WithTypeHints, HasDisplayData):
   TimestampParam = 'TimestampParam'
   WindowParam = 'WindowParam'
 
+  DoFnParams = [ElementParam, SideInputParam, TimestampParam, WindowParam]
+
   @staticmethod
   def from_callable(fn):
     return CallableWrapperDoFn(fn)
@@ -596,6 +599,10 @@ class ParDo(PTransformWithSideInputs):
     if not isinstance(self.fn, DoFn):
       raise TypeError('ParDo must be called with a DoFn instance.')
 
+    # Validate the DoFn by creating a DoFnSignature
+    from apache_beam.runners.common import DoFnSignature
+    DoFnSignature(self.fn)
+
   def default_type_hints(self):
     return self.fn.get_type_hints()
 

http://git-wip-us.apache.org/repos/asf/beam/blob/9283a5e8/sdks/python/apache_beam/transforms/ptransform_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py
index b8b0733..e712661 100644
--- a/sdks/python/apache_beam/transforms/ptransform_test.py
+++ b/sdks/python/apache_beam/transforms/ptransform_test.py
@@ -303,7 +303,6 @@ class PTransformTest(unittest.TestCase):
 
       def start_bundle(self):
         self.state = 'started'
-        return None
 
       def process(self, element):
         if self.state == 'started':


Mime
View raw message