spark-commits mailing list archives

From gurwls...@apache.org
Subject spark git commit: [SPARK-23754][PYTHON] Re-raising StopIteration in client code
Date Wed, 30 May 2018 10:11:39 GMT
Repository: spark
Updated Branches:
  refs/heads/master a4be981c0 -> 0ebb0c0d4


[SPARK-23754][PYTHON] Re-raising StopIteration in client code

## What changes were proposed in this pull request?

Make sure that a `StopIteration` raised in users' code does not silently interrupt processing
by Spark, but is surfaced to the user as an exception. User functions are wrapped with
`fail_on_stopiteration` (in `util.py`), which re-raises `StopIteration` as `RuntimeError`.
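
For illustration, here is a minimal standalone sketch (plain Python, not Spark code) of the
failure mode and the wrapping behaviour; the `wrap` helper below is a simplified stand-in for
the `fail_on_stopiteration` wrapper added in `util.py`:

```python
# A StopIteration escaping from a user function is swallowed by the consuming
# loop, which treats it as normal end-of-iteration and silently drops data.
def wrap(f):
    # simplified stand-in for pyspark.util.fail_on_stopiteration
    def wrapper(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except StopIteration as exc:
            raise RuntimeError("StopIteration thrown from user's code", exc)
    return wrapper

def buggy(x):
    if x == 2:
        raise StopIteration()  # user bug
    return x * 10

print(list(map(buggy, range(5))))      # [0, 10] -- silently truncated
try:
    list(map(wrap(buggy), range(5)))
except RuntimeError as e:
    print("caught:", e)                # error surfaced instead of lost data
```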

## How was this patch tested?

Unit tests, making sure that the exceptions are indeed raised. I am not sure how to check
whether a `Py4JJavaError` contains my exception, so I simply looked for the exception message
in the Java exception's `toString`. Can you propose a better way?
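
For reference, a condensed sketch of the assertion pattern used in the new SQL test (it assumes
an active `SparkSession` named `spark`); the full test is `test_stopiteration_in_udf` in
`python/pyspark/sql/tests.py`:

```python
from py4j.protocol import Py4JJavaError
from pyspark.sql.functions import udf

def foo(x):
    raise StopIteration()

try:
    spark.range(0, 1000).withColumn('v', udf(foo)('id')).show()
except Py4JJavaError as e:
    # the wrapped error is only visible in the Java exception's string form
    assert "StopIteration" in e.java_exception.toString()
```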

## License

This is my original work, licensed in the same way as Spark.

Author: e-dorigatti <emilio.dorigatti@gmail.com>
Author: edorigatti <emilio.dorigatti@gmail.com>

Closes #21383 from e-dorigatti/fix_spark_23754.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0ebb0c0d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0ebb0c0d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0ebb0c0d

Branch: refs/heads/master
Commit: 0ebb0c0d4dd3e192464dc5e0e6f01efa55b945ed
Parents: a4be981
Author: e-dorigatti <emilio.dorigatti@gmail.com>
Authored: Wed May 30 18:11:33 2018 +0800
Committer: hyukjinkwon <gurwls223@apache.org>
Committed: Wed May 30 18:11:33 2018 +0800

----------------------------------------------------------------------
 python/pyspark/rdd.py       | 18 +++++++++++---
 python/pyspark/shuffle.py   |  7 +++---
 python/pyspark/sql/tests.py | 16 ++++++++++++
 python/pyspark/sql/udf.py   | 14 +++++++++--
 python/pyspark/tests.py     | 53 ++++++++++++++++++++++++++++++++++++++++
 python/pyspark/util.py      | 28 ++++++++++++++++++---
 6 files changed, 125 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0ebb0c0d/python/pyspark/rdd.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index d5a237a..14d9128 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -53,6 +53,7 @@ from pyspark.resultiterable import ResultIterable
 from pyspark.shuffle import Aggregator, ExternalMerger, \
     get_used_memory, ExternalSorter, ExternalGroupBy
 from pyspark.traceback_utils import SCCallSiteSync
+from pyspark.util import fail_on_stopiteration
 
 
 __all__ = ["RDD"]
@@ -339,7 +340,7 @@ class RDD(object):
         [('a', 1), ('b', 1), ('c', 1)]
         """
         def func(_, iterator):
-            return map(f, iterator)
+            return map(fail_on_stopiteration(f), iterator)
         return self.mapPartitionsWithIndex(func, preservesPartitioning)
 
     def flatMap(self, f, preservesPartitioning=False):
@@ -354,7 +355,7 @@ class RDD(object):
         [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)]
         """
         def func(s, iterator):
-            return chain.from_iterable(map(f, iterator))
+            return chain.from_iterable(map(fail_on_stopiteration(f), iterator))
         return self.mapPartitionsWithIndex(func, preservesPartitioning)
 
     def mapPartitions(self, f, preservesPartitioning=False):
@@ -417,7 +418,7 @@ class RDD(object):
         [2, 4]
         """
         def func(iterator):
-            return filter(f, iterator)
+            return filter(fail_on_stopiteration(f), iterator)
         return self.mapPartitions(func, True)
 
     def distinct(self, numPartitions=None):
@@ -798,6 +799,8 @@ class RDD(object):
         >>> def f(x): print(x)
         >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f)
         """
+        f = fail_on_stopiteration(f)
+
         def processPartition(iterator):
             for x in iterator:
                 f(x)
@@ -847,6 +850,8 @@ class RDD(object):
             ...
         ValueError: Can not reduce() empty RDD
         """
+        f = fail_on_stopiteration(f)
+
         def func(iterator):
             iterator = iter(iterator)
             try:
@@ -918,6 +923,8 @@ class RDD(object):
         >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add)
         15
         """
+        op = fail_on_stopiteration(op)
+
         def func(iterator):
             acc = zeroValue
             for obj in iterator:
@@ -950,6 +957,9 @@ class RDD(object):
         >>> sc.parallelize([]).aggregate((0, 0), seqOp, combOp)
         (0, 0)
         """
+        seqOp = fail_on_stopiteration(seqOp)
+        combOp = fail_on_stopiteration(combOp)
+
         def func(iterator):
             acc = zeroValue
             for obj in iterator:
@@ -1643,6 +1653,8 @@ class RDD(object):
         >>> sorted(rdd.reduceByKeyLocally(add).items())
         [('a', 2), ('b', 1)]
         """
+        func = fail_on_stopiteration(func)
+
         def reducePartition(iterator):
             m = {}
             for k, v in iterator:

http://git-wip-us.apache.org/repos/asf/spark/blob/0ebb0c0d/python/pyspark/shuffle.py
----------------------------------------------------------------------
diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py
index 02c7733..bd0ac00 100644
--- a/python/pyspark/shuffle.py
+++ b/python/pyspark/shuffle.py
@@ -28,6 +28,7 @@ import sys
 import pyspark.heapq3 as heapq
 from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \
     CompressedSerializer, AutoBatchedSerializer
+from pyspark.util import fail_on_stopiteration
 
 
 try:
@@ -94,9 +95,9 @@ class Aggregator(object):
     """
 
     def __init__(self, createCombiner, mergeValue, mergeCombiners):
-        self.createCombiner = createCombiner
-        self.mergeValue = mergeValue
-        self.mergeCombiners = mergeCombiners
+        self.createCombiner = fail_on_stopiteration(createCombiner)
+        self.mergeValue = fail_on_stopiteration(mergeValue)
+        self.mergeCombiners = fail_on_stopiteration(mergeCombiners)
 
 
 class SimpleAggregator(Aggregator):

http://git-wip-us.apache.org/repos/asf/spark/blob/0ebb0c0d/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index c7bd8f0..a245093 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -900,6 +900,22 @@ class SQLTests(ReusedSQLTestCase):
         self.assertEqual(f, f_.func)
         self.assertEqual(return_type, f_.returnType)
 
+    def test_stopiteration_in_udf(self):
+        # test for SPARK-23754
+        from pyspark.sql.functions import udf
+        from py4j.protocol import Py4JJavaError
+
+        def foo(x):
+            raise StopIteration()
+
+        with self.assertRaises(Py4JJavaError) as cm:
+            self.spark.range(0, 1000).withColumn('v', udf(foo)('id')).show()
+
+        self.assertIn(
+            "Caught StopIteration thrown from user's code; failing the task",
+            cm.exception.java_exception.toString()
+        )
+
     def test_validate_column_types(self):
         from pyspark.sql.functions import udf, to_json
         from pyspark.sql.column import _to_java_column

http://git-wip-us.apache.org/repos/asf/spark/blob/0ebb0c0d/python/pyspark/sql/udf.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 9dbe49b..c8fb49d 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -25,7 +25,7 @@ from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_
 from pyspark.sql.column import Column, _to_java_column, _to_seq
 from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string,\
     to_arrow_type, to_arrow_schema
-from pyspark.util import _get_argspec
+from pyspark.util import _get_argspec, fail_on_stopiteration
 
 __all__ = ["UDFRegistration"]
 
@@ -157,7 +157,17 @@ class UserDefinedFunction(object):
         spark = SparkSession.builder.getOrCreate()
         sc = spark.sparkContext
 
-        wrapped_func = _wrap_function(sc, self.func, self.returnType)
+        func = fail_on_stopiteration(self.func)
+
+        # for pandas UDFs the worker needs to know if the function takes
+        # one or two arguments, but the signature is lost when wrapping with
+        # fail_on_stopiteration, so we store it here
+        if self.evalType in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
+                             PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+                             PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF):
+            func._argspec = _get_argspec(self.func)
+
+        wrapped_func = _wrap_function(sc, func, self.returnType)
         jdt = spark._jsparkSession.parseDataType(self.returnType.json())
         judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
             self._name, wrapped_func, jdt, self.evalType, self.deterministic)

http://git-wip-us.apache.org/repos/asf/spark/blob/0ebb0c0d/python/pyspark/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 498d6b5..3b37cc0 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -161,6 +161,37 @@ class MergerTests(unittest.TestCase):
             self.assertEqual(k, len(vs))
             self.assertEqual(list(range(k)), list(vs))
 
+    def test_stopiteration_is_raised(self):
+
+        def stopit(*args, **kwargs):
+            raise StopIteration()
+
+        def legit_create_combiner(x):
+            return [x]
+
+        def legit_merge_value(x, y):
+            return x.append(y) or x
+
+        def legit_merge_combiners(x, y):
+            return x.extend(y) or x
+
+        data = [(x % 2, x) for x in range(100)]
+
+        # wrong create combiner
+        m = ExternalMerger(Aggregator(stopit, legit_merge_value, legit_merge_combiners), 20)
+        with self.assertRaises((Py4JJavaError, RuntimeError)) as cm:
+            m.mergeValues(data)
+
+        # wrong merge value
+        m = ExternalMerger(Aggregator(legit_create_combiner, stopit, legit_merge_combiners), 20)
+        with self.assertRaises((Py4JJavaError, RuntimeError)) as cm:
+            m.mergeValues(data)
+
+        # wrong merge combiners
+        m = ExternalMerger(Aggregator(legit_create_combiner, legit_merge_value, stopit), 20)
+        with self.assertRaises((Py4JJavaError, RuntimeError)) as cm:
+            m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), data))
+
 
 class SorterTests(unittest.TestCase):
     def test_in_memory_sort(self):
@@ -1246,6 +1277,28 @@ class RDDTests(ReusedPySparkTestCase):
         result = rdd.pipe('cat').collect()
         self.assertEqual(data, result)
 
+    def test_stopiteration_in_client_code(self):
+
+        def stopit(*x):
+            raise StopIteration()
+
+        seq_rdd = self.sc.parallelize(range(10))
+        keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10))
+
+        self.assertRaises(Py4JJavaError, seq_rdd.map(stopit).collect)
+        self.assertRaises(Py4JJavaError, seq_rdd.filter(stopit).collect)
+        self.assertRaises(Py4JJavaError, seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect)
+        self.assertRaises(Py4JJavaError, seq_rdd.foreach, stopit)
+        self.assertRaises(Py4JJavaError, keyed_rdd.reduceByKeyLocally, stopit)
+        self.assertRaises(Py4JJavaError, seq_rdd.reduce, stopit)
+        self.assertRaises(Py4JJavaError, seq_rdd.fold, 0, stopit)
+
+        # the exception raised is non-deterministic
+        self.assertRaises((Py4JJavaError, RuntimeError),
+                          seq_rdd.aggregate, 0, stopit, lambda *x: 1)
+        self.assertRaises((Py4JJavaError, RuntimeError),
+                          seq_rdd.aggregate, 0, lambda *x: 1, stopit)
+
 
 class ProfilerTests(PySparkTestCase):
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0ebb0c0d/python/pyspark/util.py
----------------------------------------------------------------------
diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index 59cc2a6..e95a9b5 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -53,11 +53,16 @@ def _get_argspec(f):
     """
     Get argspec of a function. Supports both Python 2 and Python 3.
     """
-    # `getargspec` is deprecated since python3.0 (incompatible with function annotations).
-    # See SPARK-23569.
-    if sys.version_info[0] < 3:
+
+    if hasattr(f, '_argspec'):
+        # only used for pandas UDF: they wrap the user function, losing its signature
+        # workers need this signature, so UDF saves it here
+        argspec = f._argspec
+    elif sys.version_info[0] < 3:
         argspec = inspect.getargspec(f)
     else:
+        # `getargspec` is deprecated since python3.0 (incompatible with function annotations).
+        # See SPARK-23569.
         argspec = inspect.getfullargspec(f)
     return argspec
 
@@ -89,6 +94,23 @@ class VersionUtils(object):
                              " version numbers.")
 
 
+def fail_on_stopiteration(f):
+    """
+    Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError'
+    prevents silent loss of data when 'f' is used in a for loop
+    """
+    def wrapper(*args, **kwargs):
+        try:
+            return f(*args, **kwargs)
+        except StopIteration as exc:
+            raise RuntimeError(
+                "Caught StopIteration thrown from user's code; failing the task",
+                exc
+            )
+
+    return wrapper
+
+
 if __name__ == "__main__":
     import doctest
     (failure_count, test_count) = doctest.testmod()



